diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index b515dbff9f3..af1538ad0c1 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -30,6 +30,7 @@ jobs: - wheel-tests-cudf - wheel-build-cudf-polars - wheel-tests-cudf-polars + - cudf-polars-polars-tests - wheel-build-dask-cudf - wheel-tests-dask-cudf - devcontainer @@ -244,6 +245,17 @@ jobs: # This always runs, but only fails if this PR touches code in # pylibcudf or cudf_polars script: "ci/test_wheel_cudf_polars.sh" + cudf-polars-polars-tests: + needs: wheel-build-cudf-polars + secrets: inherit + uses: rapidsai/shared-workflows/.github/workflows/wheels-test.yaml@branch-24.10 + with: + # This selects "ARCH=amd64 + the latest supported Python + CUDA". + matrix_filter: map(select(.ARCH == "amd64")) | group_by(.CUDA_VER|split(".")|map(tonumber)|.[0]) | map(max_by([(.PY_VER|split(".")|map(tonumber)), (.CUDA_VER|split(".")|map(tonumber))])) + build_type: pull-request + # This always runs, but only fails if this PR touches code in + # pylibcudf or cudf_polars + script: "ci/test_cudf_polars_polars_tests.sh" wheel-build-dask-cudf: needs: wheel-build-cudf secrets: inherit diff --git a/ci/run_cudf_polars_polars_tests.sh b/ci/run_cudf_polars_polars_tests.sh new file mode 100755 index 00000000000..52a827af94c --- /dev/null +++ b/ci/run_cudf_polars_polars_tests.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# Copyright (c) 2024, NVIDIA CORPORATION. + +set -euo pipefail + +# Support invoking run_cudf_polars_pytests.sh outside the script directory +# Assumption, polars has been cloned in the root of the repo. +cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"/../polars/ + +DESELECTED_TESTS=( + "tests/unit/test_polars_import.py::test_polars_import" # relies on a polars built in place + "tests/unit/streaming/test_streaming_sort.py::test_streaming_sort[True]" # relies on polars built in debug mode + "tests/unit/test_cpu_check.py::test_check_cpu_flags_skipped_no_flags" # Mock library error + "tests/docs/test_user_guide.py" # No dot binary in CI image +) + +DESELECTED_TESTS=$(printf -- " --deselect %s" "${DESELECTED_TESTS[@]}") +python -m pytest \ + --import-mode=importlib \ + --cache-clear \ + -m "" \ + -p cudf_polars.testing.plugin \ + -v \ + --tb=short \ + ${DESELECTED_TESTS} \ + "$@" \ + py-polars/tests diff --git a/ci/test_cudf_polars_polars_tests.sh b/ci/test_cudf_polars_polars_tests.sh new file mode 100755 index 00000000000..6c728a9537f --- /dev/null +++ b/ci/test_cudf_polars_polars_tests.sh @@ -0,0 +1,69 @@ +#!/bin/bash +# Copyright (c) 2024, NVIDIA CORPORATION. + +set -eou pipefail + +# We will only fail these tests if the PR touches code in pylibcudf +# or cudf_polars itself. +# Note, the three dots mean we are doing diff between the merge-base +# of upstream and HEAD. So this is asking, "does _this branch_ touch +# files in cudf_polars/pylibcudf", rather than "are there changes +# between upstream and this branch which touch cudf_polars/pylibcudf" +# TODO: is the target branch exposed anywhere in an environment variable? +if [ -n "$(git diff --name-only origin/branch-24.10...HEAD -- python/cudf_polars/ python/cudf/cudf/_lib/pylibcudf/)" ]; +then + HAS_CHANGES=1 + rapids-logger "PR has changes in cudf-polars/pylibcudf, test fails treated as failure" +else + HAS_CHANGES=0 + rapids-logger "PR does not have changes in cudf-polars/pylibcudf, test fails NOT treated as failure" +fi + +rapids-logger "Download wheels" + +RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})" +RAPIDS_PY_WHEEL_NAME="cudf_polars_${RAPIDS_PY_CUDA_SUFFIX}" RAPIDS_PY_WHEEL_PURE="1" rapids-download-wheels-from-s3 ./dist + +# Download the pylibcudf built in the previous step +RAPIDS_PY_WHEEL_NAME="pylibcudf_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 ./local-pylibcudf-dep + +rapids-logger "Install pylibcudf" +python -m pip install ./local-pylibcudf-dep/pylibcudf*.whl + +rapids-logger "Install cudf_polars" +python -m pip install $(echo ./dist/cudf_polars*.whl) + +# TAG=$(python -c 'import polars; print(f"py-{polars.__version__}")') +TAG="py-1.7.0" +rapids-logger "Clone polars to ${TAG}" +git clone https://github.com/pola-rs/polars.git --branch ${TAG} --depth 1 + +# Install requirements for running polars tests +rapids-logger "Install polars test requirements" +python -m pip install -r polars/py-polars/requirements-dev.txt -r polars/py-polars/requirements-ci.txt + +function set_exitcode() +{ + EXITCODE=$? +} +EXITCODE=0 +trap set_exitcode ERR +set +e + +rapids-logger "Run polars tests" +./ci/run_cudf_polars_polars_tests.sh + +trap ERR +set -e + +if [ ${EXITCODE} != 0 ]; then + rapids-logger "Running polars test suite FAILED: exitcode ${EXITCODE}" +else + rapids-logger "Running polars test suite PASSED" +fi + +if [ ${HAS_CHANGES} == 1 ]; then + exit ${EXITCODE} +else + exit 0 +fi diff --git a/ci/test_wheel_cudf_polars.sh b/ci/test_wheel_cudf_polars.sh index 9844090258a..b4509bba02e 100755 --- a/ci/test_wheel_cudf_polars.sh +++ b/ci/test_wheel_cudf_polars.sh @@ -13,10 +13,14 @@ set -eou pipefail if [ -n "$(git diff --name-only origin/branch-24.10...HEAD -- python/cudf_polars/ python/pylibcudf/)" ]; then HAS_CHANGES=1 + rapids-logger "PR has changes in cudf-polars/pylibcudf, test fails treated as failure" else HAS_CHANGES=0 + rapids-logger "PR does not have changes in cudf-polars/pylibcudf, test fails NOT treated as failure" fi +rapids-logger "Download wheels" + RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})" RAPIDS_PY_WHEEL_NAME="cudf_polars_${RAPIDS_PY_CUDA_SUFFIX}" RAPIDS_PY_WHEEL_PURE="1" rapids-download-wheels-from-s3 python ./dist @@ -43,6 +47,9 @@ python -m pip install \ "$(echo ./dist/libcudf_${RAPIDS_PY_CUDA_SUFFIX}*.whl)" \ "$(echo ./dist/pylibcudf_${RAPIDS_PY_CUDA_SUFFIX}*.whl)" +rapids-logger "Pin to 1.7.0 Temporarily" +python -m pip install polars==1.7.0 + rapids-logger "Run cudf_polars tests" function set_exitcode() diff --git a/dependencies.yaml b/dependencies.yaml index 7a13043cc5f..2f2d7ba679e 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -650,7 +650,7 @@ dependencies: common: - output_types: [conda, requirements, pyproject] packages: - - polars>=1.0,<1.3 + - polars>=1.6 run_dask_cudf: common: - output_types: [conda, requirements, pyproject] diff --git a/docs/cudf/source/_static/Polars_GPU_speedup_80GB.png b/docs/cudf/source/_static/Polars_GPU_speedup_80GB.png new file mode 100644 index 00000000000..e472cf66612 Binary files /dev/null and b/docs/cudf/source/_static/Polars_GPU_speedup_80GB.png differ diff --git a/docs/cudf/source/_static/compute_heavy_queries_polars.png b/docs/cudf/source/_static/compute_heavy_queries_polars.png new file mode 100644 index 00000000000..6854ed5a436 Binary files /dev/null and b/docs/cudf/source/_static/compute_heavy_queries_polars.png differ diff --git a/docs/cudf/source/_static/pds_benchmark_polars.png b/docs/cudf/source/_static/pds_benchmark_polars.png new file mode 100644 index 00000000000..d0b48ab2901 Binary files /dev/null and b/docs/cudf/source/_static/pds_benchmark_polars.png differ diff --git a/docs/cudf/source/cudf_polars/index.rst b/docs/cudf/source/cudf_polars/index.rst new file mode 100644 index 00000000000..cc7aabd124f --- /dev/null +++ b/docs/cudf/source/cudf_polars/index.rst @@ -0,0 +1,41 @@ +cuDF-based GPU backend for Polars [Open Beta] +============================================= + +cuDF supports an in-memory, GPU-accelerated execution engine for Python users of the Polars Lazy API. +The engine supports most of the core expressions and data types as well as a growing set of more advanced dataframe manipulations +and data file formats. When using the GPU engine, Polars will convert expressions into an optimized query plan and determine +whether the plan is supported on the GPU. If it is not, the execution will transparently fall back to the standard Polars engine +and run on the CPU. + +Benchmark +--------- +We reproduced the `Polars Decision Support (PDS) `__ benchmark to compare Polars GPU engine with the default CPU settings across several dataset sizes. Here are the results: + +.. figure:: ../_static/pds_benchmark_polars.png + :width: 600px + + + +You can see up to 13x speedup using the GPU backend on the compute-heavy PDS queries involving complex aggregation and join operations. Below are the speedups for the top performing queries: + + +.. figure:: ../_static/compute_heavy_queries_polars.png + :width: 1000px + +:emphasis:`PDS-H benchmark | GPU: NVIDIA H100 PCIe | CPU: Intel Xeon W9-3495X (Sapphire Rapids) | Storage: Local NVMe` + +You can reproduce the results by visiting the `Polars Decision Support (PDS) GitHub repository `__. + +Learn More +---------- + +The GPU backend for Polars is now available in Open Beta and the engine is undergoing rapid development. To learn more, visit the `GPU Support page `__ on the Polars website. + +Launch on Google Colab +---------------------- + +.. figure:: ../_static/colab.png + :width: 200px + :target: https://colab.research.google.com/github/rapidsai-community/showcase/blob/main/accelerated_data_processing_examples/polars_gpu_engine_demo.ipynb + + Take the cuDF backend for Polars for a test-drive in a free GPU-enabled notebook environment using your Google account by `launching on Colab `__. diff --git a/docs/cudf/source/index.rst b/docs/cudf/source/index.rst index 3b8dfa5fe01..1b86cafeb48 100644 --- a/docs/cudf/source/index.rst +++ b/docs/cudf/source/index.rst @@ -29,5 +29,6 @@ other operations. user_guide/index cudf_pandas/index + cudf_polars/index libcudf_docs/index developer_guide/index diff --git a/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/index.rst b/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/index.rst index 2518afc80a7..003e7c0c35e 100644 --- a/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/index.rst +++ b/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/index.rst @@ -14,3 +14,4 @@ strings repeat replace slice + strip diff --git a/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/strip.rst b/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/strip.rst new file mode 100644 index 00000000000..a79774b8e67 --- /dev/null +++ b/docs/cudf/source/user_guide/api_docs/pylibcudf/strings/strip.rst @@ -0,0 +1,6 @@ +===== +strip +===== + +.. automodule:: pylibcudf.strings.strip + :members: diff --git a/python/cudf/cudf/_lib/datetime.pyx b/python/cudf/cudf/_lib/datetime.pyx index 483250dd36f..bc5e085ec39 100644 --- a/python/cudf/cudf/_lib/datetime.pyx +++ b/python/cudf/cudf/_lib/datetime.pyx @@ -17,6 +17,8 @@ from pylibcudf.libcudf.types cimport size_type from cudf._lib.column cimport Column from cudf._lib.scalar cimport DeviceScalar +import pylibcudf as plc + @acquire_spill_lock() def add_months(Column col, Column months): @@ -38,43 +40,9 @@ def add_months(Column col, Column months): @acquire_spill_lock() def extract_datetime_component(Column col, object field): - - cdef unique_ptr[column] c_result - cdef column_view col_view = col.view() - - with nogil: - if field == "year": - c_result = move(libcudf_datetime.extract_year(col_view)) - elif field == "month": - c_result = move(libcudf_datetime.extract_month(col_view)) - elif field == "day": - c_result = move(libcudf_datetime.extract_day(col_view)) - elif field == "weekday": - c_result = move(libcudf_datetime.extract_weekday(col_view)) - elif field == "hour": - c_result = move(libcudf_datetime.extract_hour(col_view)) - elif field == "minute": - c_result = move(libcudf_datetime.extract_minute(col_view)) - elif field == "second": - c_result = move(libcudf_datetime.extract_second(col_view)) - elif field == "millisecond": - c_result = move( - libcudf_datetime.extract_millisecond_fraction(col_view) - ) - elif field == "microsecond": - c_result = move( - libcudf_datetime.extract_microsecond_fraction(col_view) - ) - elif field == "nanosecond": - c_result = move( - libcudf_datetime.extract_nanosecond_fraction(col_view) - ) - elif field == "day_of_year": - c_result = move(libcudf_datetime.day_of_year(col_view)) - else: - raise ValueError(f"Invalid datetime field: '{field}'") - - result = Column.from_unique_ptr(move(c_result)) + result = Column.from_pylibcudf( + plc.datetime.extract_datetime_component(col.to_pylibcudf(mode="read"), field) + ) if field == "weekday": # Pandas counts Monday-Sunday as 0-6 diff --git a/python/cudf/cudf/_lib/string_casting.pyx b/python/cudf/cudf/_lib/string_casting.pyx index 8d463829a19..60a6795a402 100644 --- a/python/cudf/cudf/_lib/string_casting.pyx +++ b/python/cudf/cudf/_lib/string_casting.pyx @@ -20,13 +20,7 @@ from pylibcudf.libcudf.strings.convert.convert_booleans cimport ( to_booleans as cpp_to_booleans, ) from pylibcudf.libcudf.strings.convert.convert_datetime cimport ( - from_timestamps as cpp_from_timestamps, is_timestamp as cpp_is_timestamp, - to_timestamps as cpp_to_timestamps, -) -from pylibcudf.libcudf.strings.convert.convert_durations cimport ( - from_durations as cpp_from_durations, - to_durations as cpp_to_durations, ) from pylibcudf.libcudf.strings.convert.convert_floats cimport ( from_floats as cpp_from_floats, @@ -48,8 +42,12 @@ from pylibcudf.libcudf.types cimport data_type, type_id from cudf._lib.types cimport underlying_type_t_type_id +import pylibcudf as plc + import cudf +from cudf._lib.types cimport dtype_to_pylibcudf_type + def floating_to_string(Column input_col): cdef column_view input_column_view = input_col.view() @@ -522,19 +520,14 @@ def int2timestamp( A Column with date-time represented in string format """ - cdef column_view input_column_view = input_col.view() cdef string c_timestamp_format = format.encode("UTF-8") - cdef column_view input_strings_names = names.view() - - cdef unique_ptr[column] c_result - with nogil: - c_result = move( - cpp_from_timestamps( - input_column_view, - c_timestamp_format, - input_strings_names)) - - return Column.from_unique_ptr(move(c_result)) + return Column.from_pylibcudf( + plc.strings.convert.convert_datetime.from_timestamps( + input_col.to_pylibcudf(mode="read"), + c_timestamp_format, + names.to_pylibcudf(mode="read") + ) + ) def timestamp2int(Column input_col, dtype, format): @@ -551,23 +544,15 @@ def timestamp2int(Column input_col, dtype, format): A Column with string represented in date-time format """ - cdef column_view input_column_view = input_col.view() - cdef type_id tid = ( - ( - SUPPORTED_NUMPY_TO_LIBCUDF_TYPES[dtype] + dtype = dtype_to_pylibcudf_type(dtype) + cdef string c_timestamp_format = format.encode('UTF-8') + return Column.from_pylibcudf( + plc.strings.convert.convert_datetime.to_timestamps( + input_col.to_pylibcudf(mode="read"), + dtype, + c_timestamp_format ) ) - cdef data_type out_type = data_type(tid) - cdef string c_timestamp_format = format.encode('UTF-8') - cdef unique_ptr[column] c_result - with nogil: - c_result = move( - cpp_to_timestamps( - input_column_view, - out_type, - c_timestamp_format)) - - return Column.from_unique_ptr(move(c_result)) def istimestamp(Column input_col, str format): @@ -613,23 +598,15 @@ def timedelta2int(Column input_col, dtype, format): A Column with string represented in TimeDelta format """ - cdef column_view input_column_view = input_col.view() - cdef type_id tid = ( - ( - SUPPORTED_NUMPY_TO_LIBCUDF_TYPES[dtype] + dtype = dtype_to_pylibcudf_type(dtype) + cdef string c_timestamp_format = format.encode('UTF-8') + return Column.from_pylibcudf( + plc.strings.convert.convert_durations.to_durations( + input_col.to_pylibcudf(mode="read"), + dtype, + c_timestamp_format ) ) - cdef data_type out_type = data_type(tid) - cdef string c_duration_format = format.encode('UTF-8') - cdef unique_ptr[column] c_result - with nogil: - c_result = move( - cpp_to_durations( - input_column_view, - out_type, - c_duration_format)) - - return Column.from_unique_ptr(move(c_result)) def int2timedelta(Column input_col, str format): @@ -647,16 +624,13 @@ def int2timedelta(Column input_col, str format): """ - cdef column_view input_column_view = input_col.view() cdef string c_duration_format = format.encode('UTF-8') - cdef unique_ptr[column] c_result - with nogil: - c_result = move( - cpp_from_durations( - input_column_view, - c_duration_format)) - - return Column.from_unique_ptr(move(c_result)) + return Column.from_pylibcudf( + plc.strings.convert.convert_durations.from_durations( + input_col.to_pylibcudf(mode="read"), + c_duration_format + ) + ) def int2ip(Column input_col): diff --git a/python/cudf/cudf/_lib/strings/strip.pyx b/python/cudf/cudf/_lib/strings/strip.pyx index acf52cb7b9f..38ecb21a94c 100644 --- a/python/cudf/cudf/_lib/strings/strip.pyx +++ b/python/cudf/cudf/_lib/strings/strip.pyx @@ -13,6 +13,7 @@ from pylibcudf.libcudf.strings.strip cimport strip as cpp_strip from cudf._lib.column cimport Column from cudf._lib.scalar cimport DeviceScalar +import pylibcudf as plc @acquire_spill_lock() @@ -25,23 +26,14 @@ def strip(Column source_strings, """ cdef DeviceScalar repl = py_repl.device_value - - cdef unique_ptr[column] c_result - cdef column_view source_view = source_strings.view() - - cdef const string_scalar* scalar_str = ( - repl.get_raw_ptr() + return Column.from_pylibcudf( + plc.strings.strip.strip( + source_strings.to_pylibcudf(mode="read"), + plc.strings.SideType.BOTH, + repl.c_value + ) ) - with nogil: - c_result = move(cpp_strip( - source_view, - side_type.BOTH, - scalar_str[0] - )) - - return Column.from_unique_ptr(move(c_result)) - @acquire_spill_lock() def lstrip(Column source_strings, diff --git a/python/cudf_polars/cudf_polars/__init__.py b/python/cudf_polars/cudf_polars/__init__.py index 41d06f8631b..c1317e8f467 100644 --- a/python/cudf_polars/cudf_polars/__init__.py +++ b/python/cudf_polars/cudf_polars/__init__.py @@ -10,10 +10,14 @@ from __future__ import annotations +# Check we have a supported polars version +import cudf_polars.utils.versions as v from cudf_polars._version import __git_commit__, __version__ from cudf_polars.callback import execute_with_cudf from cudf_polars.dsl.translate import translate_ir +del v + __all__: list[str] = [ "execute_with_cudf", "translate_ir", diff --git a/python/cudf_polars/cudf_polars/callback.py b/python/cudf_polars/cudf_polars/callback.py index f31193aa938..76816ee0a61 100644 --- a/python/cudf_polars/cudf_polars/callback.py +++ b/python/cudf_polars/cudf_polars/callback.py @@ -5,19 +5,26 @@ from __future__ import annotations +import contextlib import os import warnings -from functools import partial +from functools import cache, partial from typing import TYPE_CHECKING import nvtx -from polars.exceptions import PerformanceWarning +from polars.exceptions import ComputeError, PerformanceWarning + +import rmm +from rmm._cuda import gpu from cudf_polars.dsl.translate import translate_ir if TYPE_CHECKING: + from collections.abc import Generator + import polars as pl + from polars import GPUEngine from cudf_polars.dsl.ir import IR from cudf_polars.typing import NodeTraverser @@ -25,23 +32,126 @@ __all__: list[str] = ["execute_with_cudf"] +@cache +def default_memory_resource(device: int) -> rmm.mr.DeviceMemoryResource: + """ + Return the default memory resource for cudf-polars. + + Parameters + ---------- + device + Disambiguating device id when selecting the device. Must be + the active device when this function is called. + + Returns + ------- + rmm.mr.DeviceMemoryResource + The default memory resource that cudf-polars uses. Currently + an async pool resource. + """ + try: + return rmm.mr.CudaAsyncMemoryResource() + except RuntimeError as e: # pragma: no cover + msg, *_ = e.args + if ( + msg.startswith("RMM failure") + and msg.find("not supported with this CUDA driver/runtime version") > -1 + ): + raise ComputeError( + "GPU engine requested, but incorrect cudf-polars package installed. " + "If your system has a CUDA 11 driver, please uninstall `cudf-polars-cu12` " + "and install `cudf-polars-cu11`" + ) from None + else: + raise + + +@contextlib.contextmanager +def set_memory_resource( + mr: rmm.mr.DeviceMemoryResource | None, +) -> Generator[rmm.mr.DeviceMemoryResource, None, None]: + """ + Set the current memory resource for an execution block. + + Parameters + ---------- + mr + Memory resource to use. If `None`, calls :func:`default_memory_resource` + to obtain an mr on the currently active device. + + Returns + ------- + Memory resource used. + + Notes + ----- + At exit, the memory resource is restored to whatever was current + at entry. If a memory resource is provided, it must be valid to + use with the currently active device. + """ + if mr is None: + device: int = gpu.getDevice() + mr = default_memory_resource(device) + previous = rmm.mr.get_current_device_resource() + rmm.mr.set_current_device_resource(mr) + try: + yield mr + finally: + rmm.mr.set_current_device_resource(previous) + + +@contextlib.contextmanager +def set_device(device: int | None) -> Generator[int, None, None]: + """ + Set the device the query is executed on. + + Parameters + ---------- + device + Device to use. If `None`, uses the current device. + + Returns + ------- + Device active for the execution of the block. + + Notes + ----- + At exit, the device is restored to whatever was current at entry. + """ + previous: int = gpu.getDevice() + if device is not None: + gpu.setDevice(device) + try: + yield previous + finally: + gpu.setDevice(previous) + + def _callback( ir: IR, with_columns: list[str] | None, pyarrow_predicate: str | None, n_rows: int | None, + *, + device: int | None, + memory_resource: int | None, ) -> pl.DataFrame: assert with_columns is None assert pyarrow_predicate is None assert n_rows is None - with nvtx.annotate(message="ExecuteIR", domain="cudf_polars"): + with ( + nvtx.annotate(message="ExecuteIR", domain="cudf_polars"), + # Device must be set before memory resource is obtained. + set_device(device), + set_memory_resource(memory_resource), + ): return ir.evaluate(cache={}).to_polars() def execute_with_cudf( nt: NodeTraverser, *, - raise_on_fail: bool = False, + config: GPUEngine, exception: type[Exception] | tuple[type[Exception], ...] = Exception, ) -> None: """ @@ -52,9 +162,8 @@ def execute_with_cudf( nt NodeTraverser - raise_on_fail - Should conversion raise an exception rather than continuing - without setting a callback. + config + GPUEngine configuration object exception Optional exception, or tuple of exceptions, to catch during @@ -62,9 +171,23 @@ def execute_with_cudf( The NodeTraverser is mutated if the libcudf executor can handle the plan. """ + device = config.device + memory_resource = config.memory_resource + raise_on_fail = config.config.get("raise_on_fail", False) + if unsupported := (config.config.keys() - {"raise_on_fail"}): + raise ValueError( + f"Engine configuration contains unsupported settings {unsupported}" + ) try: with nvtx.annotate(message="ConvertIR", domain="cudf_polars"): - nt.set_udf(partial(_callback, translate_ir(nt))) + nt.set_udf( + partial( + _callback, + translate_ir(nt), + device=device, + memory_resource=memory_resource, + ) + ) except exception as e: if bool(int(os.environ.get("POLARS_VERBOSE", 0))): warnings.warn( diff --git a/python/cudf_polars/cudf_polars/containers/column.py b/python/cudf_polars/cudf_polars/containers/column.py index dd3b771e305..3fe3e5557cb 100644 --- a/python/cudf_polars/cudf_polars/containers/column.py +++ b/python/cudf_polars/cudf_polars/containers/column.py @@ -84,6 +84,34 @@ def sorted_like(self, like: Column, /) -> Self: is_sorted=like.is_sorted, order=like.order, null_order=like.null_order ) + # TODO: Return Column once #16272 is fixed. + def astype(self, dtype: plc.DataType) -> plc.Column: + """ + Return the backing column as the requested dtype. + + Parameters + ---------- + dtype + Datatype to cast to. + + Returns + ------- + Column of requested type. + + Raises + ------ + RuntimeError + If the cast is unsupported. + + Notes + ----- + This only produces a copy if the requested dtype doesn't match + the current one. + """ + if self.obj.type() != dtype: + return plc.unary.cast(self.obj, dtype) + return self.obj + def copy_metadata(self, from_: pl.Series, /) -> Self: """ Copy metadata from a host series onto self. diff --git a/python/cudf_polars/cudf_polars/containers/dataframe.py b/python/cudf_polars/cudf_polars/containers/dataframe.py index a5c99e2bc11..f3e3862d0cc 100644 --- a/python/cudf_polars/cudf_polars/containers/dataframe.py +++ b/python/cudf_polars/cudf_polars/containers/dataframe.py @@ -7,7 +7,7 @@ import itertools from functools import cached_property -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING import pyarrow as pa import pylibcudf as plc @@ -45,11 +45,19 @@ def copy(self) -> Self: def to_polars(self) -> pl.DataFrame: """Convert to a polars DataFrame.""" + # If the arrow table has empty names, from_arrow produces + # column_$i. But here we know there is only one such column + # (by construction) and it should have an empty name. + # https://github.com/pola-rs/polars/issues/11632 + # To guarantee we produce correct names, we therefore + # serialise with names we control and rename with that map. + name_map = {f"column_{i}": c.name for i, c in enumerate(self.columns)} table: pa.Table = plc.interop.to_arrow( self.table, - [plc.interop.ColumnMetadata(name=c.name) for c in self.columns], + [plc.interop.ColumnMetadata(name=name) for name in name_map], ) - return cast(pl.DataFrame, pl.from_arrow(table)).with_columns( + df: pl.DataFrame = pl.from_arrow(table) + return df.rename(name_map).with_columns( *( pl.col(c.name).set_sorted( descending=c.order == plc.types.Order.DESCENDING diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py index e1b4d30b76b..c401e5a2f17 100644 --- a/python/cudf_polars/cudf_polars/dsl/expr.py +++ b/python/cudf_polars/cudf_polars/dsl/expr.py @@ -21,8 +21,10 @@ from typing import TYPE_CHECKING, Any, ClassVar, NamedTuple import pyarrow as pa +import pyarrow.compute as pc import pylibcudf as plc +from polars.exceptions import InvalidOperationError from polars.polars import _expr_nodes as pl_expr from cudf_polars.containers import Column, NamedColumn @@ -477,12 +479,6 @@ def __init__( self.options = options self.name = name self.children = children - if ( - self.name in (pl_expr.BooleanFunction.Any, pl_expr.BooleanFunction.All) - and not self.options[0] - ): - # With ignore_nulls == False, polars uses Kleene logic - raise NotImplementedError(f"Kleene logic for {self.name}") if self.name == pl_expr.BooleanFunction.IsIn and not all( c.dtype == self.children[0].dtype for c in self.children ): @@ -577,20 +573,31 @@ def do_evaluate( child.evaluate(df, context=context, mapping=mapping) for child in self.children ] - if self.name == pl_expr.BooleanFunction.Any: + # Kleene logic for Any (OR) and All (AND) if ignore_nulls is + # False + if self.name in (pl_expr.BooleanFunction.Any, pl_expr.BooleanFunction.All): + (ignore_nulls,) = self.options (column,) = columns - return Column( - plc.Column.from_scalar( - plc.reduce.reduce(column.obj, plc.aggregation.any(), self.dtype), 1 - ) - ) - elif self.name == pl_expr.BooleanFunction.All: - (column,) = columns - return Column( - plc.Column.from_scalar( - plc.reduce.reduce(column.obj, plc.aggregation.all(), self.dtype), 1 - ) - ) + is_any = self.name == pl_expr.BooleanFunction.Any + agg = plc.aggregation.any() if is_any else plc.aggregation.all() + result = plc.reduce.reduce(column.obj, agg, self.dtype) + if not ignore_nulls and column.obj.null_count() > 0: + # Truth tables + # Any All + # | F U T | F U T + # --+------ --+------ + # F | F U T F | F F F + # U | U U T U | F U U + # T | T T T T | F U T + # + # If the input null count was non-zero, we must + # post-process the result to insert the correct value. + h_result = plc.interop.to_arrow(result).as_py() + if is_any and not h_result or not is_any and h_result: + # Any All + # False || Null => Null True && Null => Null + return Column(plc.Column.all_null_like(column.obj, 1)) + return Column(plc.Column.from_scalar(result, 1)) if self.name == pl_expr.BooleanFunction.IsNull: (column,) = columns return Column(plc.unary.is_null(column.obj)) @@ -598,13 +605,19 @@ def do_evaluate( (column,) = columns return Column(plc.unary.is_valid(column.obj)) elif self.name == pl_expr.BooleanFunction.IsNan: - # TODO: copy over null mask since is_nan(null) => null in polars (column,) = columns - return Column(plc.unary.is_nan(column.obj)) + return Column( + plc.unary.is_nan(column.obj).with_mask( + column.obj.null_mask(), column.obj.null_count() + ) + ) elif self.name == pl_expr.BooleanFunction.IsNotNan: - # TODO: copy over null mask since is_not_nan(null) => null in polars (column,) = columns - return Column(plc.unary.is_not_nan(column.obj)) + return Column( + plc.unary.is_not_nan(column.obj).with_mask( + column.obj.null_mask(), column.obj.null_count() + ) + ) elif self.name == pl_expr.BooleanFunction.IsFirstDistinct: (column,) = columns return self._distinct( @@ -654,26 +667,22 @@ def do_evaluate( ), ) elif self.name == pl_expr.BooleanFunction.AllHorizontal: - if any(c.obj.null_count() > 0 for c in columns): - raise NotImplementedError("Kleene logic for all_horizontal") return Column( reduce( partial( plc.binaryop.binary_operation, - op=plc.binaryop.BinaryOperator.BITWISE_AND, + op=plc.binaryop.BinaryOperator.NULL_LOGICAL_AND, output_type=self.dtype, ), (c.obj for c in columns), ) ) elif self.name == pl_expr.BooleanFunction.AnyHorizontal: - if any(c.obj.null_count() > 0 for c in columns): - raise NotImplementedError("Kleene logic for any_horizontal") return Column( reduce( partial( plc.binaryop.binary_operation, - op=plc.binaryop.BinaryOperator.BITWISE_OR, + op=plc.binaryop.BinaryOperator.NULL_LOGICAL_OR, output_type=self.dtype, ), (c.obj for c in columns), @@ -694,7 +703,7 @@ def do_evaluate( class StringFunction(Expr): - __slots__ = ("name", "options", "children") + __slots__ = ("name", "options", "children", "_regex_program") _non_child = ("dtype", "name", "options") children: tuple[Expr, ...] @@ -713,12 +722,18 @@ def __init__( def _validate_input(self): if self.name not in ( - pl_expr.StringFunction.Lowercase, - pl_expr.StringFunction.Uppercase, - pl_expr.StringFunction.EndsWith, - pl_expr.StringFunction.StartsWith, pl_expr.StringFunction.Contains, + pl_expr.StringFunction.EndsWith, + pl_expr.StringFunction.Lowercase, + pl_expr.StringFunction.Replace, + pl_expr.StringFunction.ReplaceMany, pl_expr.StringFunction.Slice, + pl_expr.StringFunction.Strptime, + pl_expr.StringFunction.StartsWith, + pl_expr.StringFunction.StripChars, + pl_expr.StringFunction.StripCharsStart, + pl_expr.StringFunction.StripCharsEnd, + pl_expr.StringFunction.Uppercase, ): raise NotImplementedError(f"String function {self.name}") if self.name == pl_expr.StringFunction.Contains: @@ -732,11 +747,65 @@ def _validate_input(self): raise NotImplementedError( "Regex contains only supports a scalar pattern" ) + pattern = self.children[1].value.as_py() + try: + self._regex_program = plc.strings.regex_program.RegexProgram.create( + pattern, + flags=plc.strings.regex_flags.RegexFlags.DEFAULT, + ) + except RuntimeError as e: + raise NotImplementedError( + f"Unsupported regex {pattern} for GPU engine." + ) from e + elif self.name == pl_expr.StringFunction.Replace: + _, literal = self.options + if not literal: + raise NotImplementedError("literal=False is not supported for replace") + if not all(isinstance(expr, Literal) for expr in self.children[1:]): + raise NotImplementedError("replace only supports scalar target") + target = self.children[1] + if target.value == pa.scalar("", type=pa.string()): + raise NotImplementedError( + "libcudf replace does not support empty strings" + ) + elif self.name == pl_expr.StringFunction.ReplaceMany: + (ascii_case_insensitive,) = self.options + if ascii_case_insensitive: + raise NotImplementedError( + "ascii_case_insensitive not implemented for replace_many" + ) + if not all( + isinstance(expr, (LiteralColumn, Literal)) for expr in self.children[1:] + ): + raise NotImplementedError("replace_many only supports literal inputs") + target = self.children[1] + if pc.any(pc.equal(target.value, "")).as_py(): + raise NotImplementedError( + "libcudf replace_many is implemented differently from polars " + "for empty strings" + ) elif self.name == pl_expr.StringFunction.Slice: if not all(isinstance(child, Literal) for child in self.children[1:]): raise NotImplementedError( "Slice only supports literal start and stop values" ) + elif self.name == pl_expr.StringFunction.Strptime: + format, _, exact, cache = self.options + if cache: + raise NotImplementedError("Strptime cache is a CPU feature") + if format is None: + raise NotImplementedError("Strptime format is required") + if not exact: + raise NotImplementedError("Strptime does not support exact=False") + elif self.name in { + pl_expr.StringFunction.StripChars, + pl_expr.StringFunction.StripCharsStart, + pl_expr.StringFunction.StripCharsEnd, + }: + if not isinstance(self.children[1], Literal): + raise NotImplementedError( + "strip operations only support scalar patterns" + ) def do_evaluate( self, @@ -759,12 +828,10 @@ def do_evaluate( else pat.obj ) return Column(plc.strings.find.contains(column.obj, pattern)) - assert isinstance(arg, Literal) - prog = plc.strings.regex_program.RegexProgram.create( - arg.value.as_py(), - flags=plc.strings.regex_flags.RegexFlags.DEFAULT, - ) - return Column(plc.strings.contains.contains_re(column.obj, prog)) + else: + return Column( + plc.strings.contains.contains_re(column.obj, self._regex_program) + ) elif self.name == pl_expr.StringFunction.Slice: child, expr_offset, expr_length = self.children assert isinstance(expr_offset, Literal) @@ -795,6 +862,22 @@ def do_evaluate( plc.interop.from_arrow(pa.scalar(stop, type=pa.int32())), ) ) + elif self.name in { + pl_expr.StringFunction.StripChars, + pl_expr.StringFunction.StripCharsStart, + pl_expr.StringFunction.StripCharsEnd, + }: + column, chars = ( + c.evaluate(df, context=context, mapping=mapping) for c in self.children + ) + if self.name == pl_expr.StringFunction.StripCharsStart: + side = plc.strings.SideType.LEFT + elif self.name == pl_expr.StringFunction.StripCharsEnd: + side = plc.strings.SideType.RIGHT + else: + side = plc.strings.SideType.BOTH + return Column(plc.strings.strip.strip(column.obj, side, chars.obj_scalar)) + columns = [ child.evaluate(df, context=context, mapping=mapping) for child in self.children @@ -825,6 +908,51 @@ def do_evaluate( else prefix.obj, ) ) + elif self.name == pl_expr.StringFunction.Strptime: + # TODO: ignores ambiguous + format, strict, exact, cache = self.options + col = self.children[0].evaluate(df, context=context, mapping=mapping) + + is_timestamps = plc.strings.convert.convert_datetime.is_timestamp( + col.obj, format.encode() + ) + + if strict: + if not plc.interop.to_arrow( + plc.reduce.reduce( + is_timestamps, + plc.aggregation.all(), + plc.DataType(plc.TypeId.BOOL8), + ) + ).as_py(): + raise InvalidOperationError("conversion from `str` failed.") + else: + not_timestamps = plc.unary.unary_operation( + is_timestamps, plc.unary.UnaryOperator.NOT + ) + + null = plc.interop.from_arrow(pa.scalar(None, type=pa.string())) + res = plc.copying.boolean_mask_scatter( + [null], plc.Table([col.obj]), not_timestamps + ) + return Column( + plc.strings.convert.convert_datetime.to_timestamps( + res.columns()[0], self.dtype, format.encode() + ) + ) + elif self.name == pl_expr.StringFunction.Replace: + column, target, repl = columns + n, _ = self.options + return Column( + plc.strings.replace.replace( + column.obj, target.obj_scalar, repl.obj_scalar, maxrepl=n + ) + ) + elif self.name == pl_expr.StringFunction.ReplaceMany: + column, target, repl = columns + return Column( + plc.strings.replace.replace_multiple(column.obj, target.obj, repl.obj) + ) raise NotImplementedError( f"StringFunction {self.name}" ) # pragma: no cover; handled by init raising @@ -832,6 +960,18 @@ def do_evaluate( class TemporalFunction(Expr): __slots__ = ("name", "options", "children") + _COMPONENT_MAP: ClassVar[dict[pl_expr.TemporalFunction, str]] = { + pl_expr.TemporalFunction.Year: "year", + pl_expr.TemporalFunction.Month: "month", + pl_expr.TemporalFunction.Day: "day", + pl_expr.TemporalFunction.WeekDay: "weekday", + pl_expr.TemporalFunction.Hour: "hour", + pl_expr.TemporalFunction.Minute: "minute", + pl_expr.TemporalFunction.Second: "second", + pl_expr.TemporalFunction.Millisecond: "millisecond", + pl_expr.TemporalFunction.Microsecond: "microsecond", + pl_expr.TemporalFunction.Nanosecond: "nanosecond", + } _non_child = ("dtype", "name", "options") children: tuple[Expr, ...] @@ -846,8 +986,8 @@ def __init__( self.options = options self.name = name self.children = children - if self.name != pl_expr.TemporalFunction.Year: - raise NotImplementedError(f"String function {self.name}") + if self.name not in self._COMPONENT_MAP: + raise NotImplementedError(f"Temporal function {self.name}") def do_evaluate( self, @@ -861,12 +1001,59 @@ def do_evaluate( child.evaluate(df, context=context, mapping=mapping) for child in self.children ] - if self.name == pl_expr.TemporalFunction.Year: - (column,) = columns - return Column(plc.datetime.extract_year(column.obj)) - raise NotImplementedError( - f"TemporalFunction {self.name}" - ) # pragma: no cover; init trips first + (column,) = columns + if self.name == pl_expr.TemporalFunction.Microsecond: + millis = plc.datetime.extract_datetime_component(column.obj, "millisecond") + micros = plc.datetime.extract_datetime_component(column.obj, "microsecond") + millis_as_micros = plc.binaryop.binary_operation( + millis, + plc.interop.from_arrow(pa.scalar(1_000, type=pa.int32())), + plc.binaryop.BinaryOperator.MUL, + plc.DataType(plc.TypeId.INT32), + ) + total_micros = plc.binaryop.binary_operation( + micros, + millis_as_micros, + plc.binaryop.BinaryOperator.ADD, + plc.types.DataType(plc.types.TypeId.INT32), + ) + return Column(total_micros) + elif self.name == pl_expr.TemporalFunction.Nanosecond: + millis = plc.datetime.extract_datetime_component(column.obj, "millisecond") + micros = plc.datetime.extract_datetime_component(column.obj, "microsecond") + nanos = plc.datetime.extract_datetime_component(column.obj, "nanosecond") + millis_as_nanos = plc.binaryop.binary_operation( + millis, + plc.interop.from_arrow(pa.scalar(1_000_000, type=pa.int32())), + plc.binaryop.BinaryOperator.MUL, + plc.types.DataType(plc.types.TypeId.INT32), + ) + micros_as_nanos = plc.binaryop.binary_operation( + micros, + plc.interop.from_arrow(pa.scalar(1_000, type=pa.int32())), + plc.binaryop.BinaryOperator.MUL, + plc.types.DataType(plc.types.TypeId.INT32), + ) + total_nanos = plc.binaryop.binary_operation( + nanos, + millis_as_nanos, + plc.binaryop.BinaryOperator.ADD, + plc.types.DataType(plc.types.TypeId.INT32), + ) + total_nanos = plc.binaryop.binary_operation( + total_nanos, + micros_as_nanos, + plc.binaryop.BinaryOperator.ADD, + plc.types.DataType(plc.types.TypeId.INT32), + ) + return Column(total_nanos) + + return Column( + plc.datetime.extract_datetime_component( + column.obj, + self._COMPONENT_MAP[self.name], + ) + ) class UnaryFunction(Expr): @@ -874,6 +1061,51 @@ class UnaryFunction(Expr): _non_child = ("dtype", "name", "options") children: tuple[Expr, ...] + # Note: log, and pow are handled via translation to binops + _OP_MAPPING: ClassVar[dict[str, plc.unary.UnaryOperator]] = { + "sin": plc.unary.UnaryOperator.SIN, + "cos": plc.unary.UnaryOperator.COS, + "tan": plc.unary.UnaryOperator.TAN, + "arcsin": plc.unary.UnaryOperator.ARCSIN, + "arccos": plc.unary.UnaryOperator.ARCCOS, + "arctan": plc.unary.UnaryOperator.ARCTAN, + "sinh": plc.unary.UnaryOperator.SINH, + "cosh": plc.unary.UnaryOperator.COSH, + "tanh": plc.unary.UnaryOperator.TANH, + "arcsinh": plc.unary.UnaryOperator.ARCSINH, + "arccosh": plc.unary.UnaryOperator.ARCCOSH, + "arctanh": plc.unary.UnaryOperator.ARCTANH, + "exp": plc.unary.UnaryOperator.EXP, + "sqrt": plc.unary.UnaryOperator.SQRT, + "cbrt": plc.unary.UnaryOperator.CBRT, + "ceil": plc.unary.UnaryOperator.CEIL, + "floor": plc.unary.UnaryOperator.FLOOR, + "abs": plc.unary.UnaryOperator.ABS, + "bit_invert": plc.unary.UnaryOperator.BIT_INVERT, + "not": plc.unary.UnaryOperator.NOT, + } + _supported_misc_fns = frozenset( + { + "drop_nulls", + "fill_null", + "mask_nans", + "round", + "set_sorted", + "unique", + } + ) + _supported_cum_aggs = frozenset( + { + "cum_min", + "cum_max", + "cum_prod", + "cum_sum", + } + ) + _supported_fns = frozenset().union( + _supported_misc_fns, _supported_cum_aggs, _OP_MAPPING.keys() + ) + def __init__( self, dtype: plc.DataType, name: str, options: tuple[Any, ...], *children: Expr ) -> None: @@ -881,15 +1113,15 @@ def __init__( self.name = name self.options = options self.children = children - if self.name not in ( - "mask_nans", - "round", - "setsorted", - "unique", - "dropnull", - "fill_null", - ): + + if self.name not in UnaryFunction._supported_fns: raise NotImplementedError(f"Unary function {name=}") + if self.name in UnaryFunction._supported_cum_aggs: + (reverse,) = self.options + if reverse: + raise NotImplementedError( + "reverse=True is not supported for cumulative aggregations" + ) def do_evaluate( self, @@ -947,7 +1179,7 @@ def do_evaluate( if maintain_order: return Column(column).sorted_like(values) return Column(column) - elif self.name == "setsorted": + elif self.name == "set_sorted": (column,) = ( child.evaluate(df, context=context, mapping=mapping) for child in self.children @@ -974,7 +1206,7 @@ def do_evaluate( order=order, null_order=null_order, ) - elif self.name == "dropnull": + elif self.name == "drop_nulls": (column,) = ( child.evaluate(df, context=context, mapping=mapping) for child in self.children @@ -994,13 +1226,65 @@ def do_evaluate( ) arg = evaluated.obj_scalar if evaluated.is_scalar else evaluated.obj return Column(plc.replace.replace_nulls(column.obj, arg)) - + elif self.name in self._OP_MAPPING: + column = self.children[0].evaluate(df, context=context, mapping=mapping) + if column.obj.type().id() != self.dtype.id(): + arg = plc.unary.cast(column.obj, self.dtype) + else: + arg = column.obj + return Column(plc.unary.unary_operation(arg, self._OP_MAPPING[self.name])) + elif self.name in UnaryFunction._supported_cum_aggs: + column = self.children[0].evaluate(df, context=context, mapping=mapping) + plc_col = column.obj + col_type = column.obj.type() + # cum_sum casts + # Int8, UInt8, Int16, UInt16 -> Int64 for overflow prevention + # Bool -> UInt32 + # cum_prod casts integer dtypes < int64 and bool to int64 + # See: + # https://github.com/pola-rs/polars/blob/main/crates/polars-ops/src/series/ops/cum_agg.rs + if ( + self.name == "cum_sum" + and col_type.id() + in { + plc.types.TypeId.INT8, + plc.types.TypeId.UINT8, + plc.types.TypeId.INT16, + plc.types.TypeId.UINT16, + } + ) or ( + self.name == "cum_prod" + and plc.traits.is_integral(col_type) + and plc.types.size_of(col_type) <= 4 + ): + plc_col = plc.unary.cast( + plc_col, plc.types.DataType(plc.types.TypeId.INT64) + ) + elif ( + self.name == "cum_sum" + and column.obj.type().id() == plc.types.TypeId.BOOL8 + ): + plc_col = plc.unary.cast( + plc_col, plc.types.DataType(plc.types.TypeId.UINT32) + ) + if self.name == "cum_sum": + agg = plc.aggregation.sum() + elif self.name == "cum_prod": + agg = plc.aggregation.product() + elif self.name == "cum_min": + agg = plc.aggregation.min() + elif self.name == "cum_max": + agg = plc.aggregation.max() + + return Column(plc.reduce.scan(plc_col, agg, plc.reduce.ScanType.INCLUSIVE)) raise NotImplementedError( f"Unimplemented unary function {self.name=}" ) # pragma: no cover; init trips first def collect_agg(self, *, depth: int) -> AggInfo: """Collect information about aggregations in groupbys.""" + if self.name in {"unique", "drop_nulls"} | self._supported_cum_aggs: + raise NotImplementedError(f"{self.name} in groupby") if depth == 1: # inside aggregation, need to pre-evaluate, groupby # construction has checked that we don't have nested aggs, @@ -1187,11 +1471,7 @@ class Cast(Expr): def __init__(self, dtype: plc.DataType, value: Expr) -> None: super().__init__(dtype) self.children = (value,) - if not ( - plc.traits.is_fixed_width(self.dtype) - and plc.traits.is_fixed_width(value.dtype) - and plc.unary.is_supported_cast(value.dtype, self.dtype) - ): + if not dtypes.can_cast(value.dtype, self.dtype): raise NotImplementedError( f"Can't cast {self.dtype.id().name} to {value.dtype.id().name}" ) @@ -1255,6 +1535,13 @@ def __init__( req = plc.aggregation.variance(ddof=options) elif name == "count": req = plc.aggregation.count(null_handling=plc.types.NullPolicy.EXCLUDE) + elif name == "quantile": + _, quantile = self.children + if not isinstance(quantile, Literal): + raise NotImplementedError("Only support literal quantile values") + req = plc.aggregation.quantile( + quantiles=[quantile.value.as_py()], interp=Agg.interp_mapping[options] + ) else: raise NotImplementedError( f"Unreachable, {name=} is incorrectly listed in _SUPPORTED" @@ -1286,9 +1573,18 @@ def __init__( "count", "std", "var", + "quantile", ] ) + interp_mapping: ClassVar[dict[str, plc.types.Interpolation]] = { + "nearest": plc.types.Interpolation.NEAREST, + "higher": plc.types.Interpolation.HIGHER, + "lower": plc.types.Interpolation.LOWER, + "midpoint": plc.types.Interpolation.MIDPOINT, + "linear": plc.types.Interpolation.LINEAR, + } + def collect_agg(self, *, depth: int) -> AggInfo: """Collect information about aggregations in groupbys.""" if depth >= 1: @@ -1299,7 +1595,19 @@ def collect_agg(self, *, depth: int) -> AggInfo: raise NotImplementedError("Nan propagation in groupby for min/max") (child,) = self.children ((expr, _, _),) = child.collect_agg(depth=depth + 1).requests - if self.request is None: + request = self.request + # These are handled specially here because we don't set up the + # request for the whole-frame agg because we can avoid a + # reduce for these. + if self.name == "first": + request = plc.aggregation.nth_element( + 0, null_handling=plc.types.NullPolicy.INCLUDE + ) + elif self.name == "last": + request = plc.aggregation.nth_element( + -1, null_handling=plc.types.NullPolicy.INCLUDE + ) + if request is None: raise NotImplementedError( f"Aggregation {self.name} in groupby" ) # pragma: no cover; __init__ trips first @@ -1308,7 +1616,7 @@ def collect_agg(self, *, depth: int) -> AggInfo: # Ignore nans in these groupby aggs, do this by masking # nans in the input expr = UnaryFunction(self.dtype, "mask_nans", (), expr) - return AggInfo([(expr, self.request, self)]) + return AggInfo([(expr, request, self)]) def _reduce( self, column: Column, *, request: plc.aggregation.Aggregation @@ -1380,7 +1688,10 @@ def do_evaluate( raise NotImplementedError( f"Agg in context {context}" ) # pragma: no cover; unreachable - (child,) = self.children + + # Aggregations like quantiles may have additional children that were + # preprocessed into pylibcudf requests. + child = self.children[0] return self.op(child.evaluate(df, context=context, mapping=mapping)) @@ -1425,6 +1736,11 @@ def __init__( right: Expr, ) -> None: super().__init__(dtype) + if plc.traits.is_boolean(self.dtype): + # For boolean output types, bitand and bitor implement + # boolean logic, so translate. bitxor also does, but the + # default behaviour is correct. + op = BinOp._BOOL_KLEENE_MAPPING.get(op, op) self.op = op self.children = (left, right) if not plc.binaryop.is_supported_operation( @@ -1436,6 +1752,15 @@ def __init__( f"with output type {self.dtype.id().name}" ) + _BOOL_KLEENE_MAPPING: ClassVar[ + dict[plc.binaryop.BinaryOperator, plc.binaryop.BinaryOperator] + ] = { + plc.binaryop.BinaryOperator.BITWISE_AND: plc.binaryop.BinaryOperator.NULL_LOGICAL_AND, + plc.binaryop.BinaryOperator.BITWISE_OR: plc.binaryop.BinaryOperator.NULL_LOGICAL_OR, + plc.binaryop.BinaryOperator.LOGICAL_AND: plc.binaryop.BinaryOperator.NULL_LOGICAL_AND, + plc.binaryop.BinaryOperator.LOGICAL_OR: plc.binaryop.BinaryOperator.NULL_LOGICAL_OR, + } + _MAPPING: ClassVar[dict[pl_expr.Operator, plc.binaryop.BinaryOperator]] = { pl_expr.Operator.Eq: plc.binaryop.BinaryOperator.EQUAL, pl_expr.Operator.EqValidity: plc.binaryop.BinaryOperator.NULL_EQUALS, diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index e334e6f5cc5..8cd56c8ee3a 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -15,7 +15,6 @@ import dataclasses import itertools -import types from functools import cache from pathlib import Path from typing import TYPE_CHECKING, Any, ClassVar @@ -28,7 +27,7 @@ import cudf_polars.dsl.expr as expr from cudf_polars.containers import DataFrame, NamedColumn -from cudf_polars.utils import sorting +from cudf_polars.utils import dtypes, sorting if TYPE_CHECKING: from collections.abc import Callable, MutableMapping @@ -133,8 +132,7 @@ class IR: def __post_init__(self): """Validate preconditions.""" - if any(dtype.id() == plc.TypeId.EMPTY for dtype in self.schema.values()): - raise NotImplementedError("Cannot make empty columns.") + pass # noqa: PIE790 def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """ @@ -189,32 +187,42 @@ class Scan(IR): """Cloud-related authentication options, currently ignored.""" paths: list[str] """List of paths to read from.""" - file_options: Any - """Options for reading the file. - - Attributes are: - - ``with_columns: list[str]`` of projected columns to return. - - ``n_rows: int``: Number of rows to read. - - ``row_index: tuple[name, offset] | None``: Add an integer index - column with given name. - """ + with_columns: list[str] + """Projected columns to return.""" + skip_rows: int + """Rows to skip at the start when reading.""" + n_rows: int + """Number of rows to read after skipping.""" + row_index: tuple[str, int] | None + """If not None add an integer index column of the given name.""" predicate: expr.NamedExpr | None """Mask to apply to the read dataframe.""" def __post_init__(self) -> None: """Validate preconditions.""" + super().__post_init__() if self.typ not in ("csv", "parquet", "ndjson"): # pragma: no cover # This line is unhittable ATM since IPC/Anonymous scan raise # on the polars side raise NotImplementedError(f"Unhandled scan type: {self.typ}") - if self.typ == "ndjson" and self.file_options.n_rows is not None: - raise NotImplementedError("row limit in scan") + if self.typ == "ndjson" and (self.n_rows != -1 or self.skip_rows != 0): + raise NotImplementedError("row limit in scan for json reader") + if self.skip_rows < 0: + # TODO: polars has this implemented for parquet, + # maybe we can do this too? + raise NotImplementedError("slice pushdown for negative slices") + if self.typ == "csv" and self.skip_rows != 0: # pragma: no cover + # This comes from slice pushdown, but that + # optimization doesn't happen right now + raise NotImplementedError("skipping rows in CSV reader") if self.cloud_options is not None and any( self.cloud_options.get(k) is not None for k in ("aws", "azure", "gcp") ): raise NotImplementedError( "Read from cloud storage" ) # pragma: no cover; no test yet + if any(p.startswith("https://") for p in self.paths): + raise NotImplementedError("Read from https") if self.typ == "csv": if self.reader_options["skip_rows_after_header"] != 0: raise NotImplementedError("Skipping rows after header in CSV reader") @@ -242,13 +250,21 @@ def __post_init__(self) -> None: raise NotImplementedError( "ignore_errors is not supported in the JSON reader" ) + elif ( + self.typ == "parquet" + and self.row_index is not None + and self.with_columns is not None + and len(self.with_columns) == 0 + ): + raise NotImplementedError( + "Reading only parquet metadata to produce row index." + ) def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" - options = self.file_options - with_columns = options.with_columns - row_index = options.row_index - nrows = self.file_options.n_rows if self.file_options.n_rows is not None else -1 + with_columns = self.with_columns + row_index = self.row_index + n_rows = self.n_rows if self.typ == "csv": parse_options = self.reader_options["parse_options"] sep = chr(parse_options["separator"]) @@ -256,7 +272,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: eol = chr(parse_options["eol_char"]) if self.reader_options["schema"] is not None: # Reader schema provides names - column_names = list(self.reader_options["schema"]["inner"].keys()) + column_names = list(self.reader_options["schema"]["fields"].keys()) else: # file provides column names column_names = None @@ -282,6 +298,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: # polars skips blank lines at the beginning of the file pieces = [] + read_partial = n_rows != -1 for p in self.paths: skiprows = self.reader_options["skip_rows"] path = Path(p) @@ -303,9 +320,13 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: comment=comment, decimal=decimal, dtypes=self.schema, - nrows=nrows, + nrows=n_rows, ) pieces.append(tbl_w_meta) + if read_partial: + n_rows -= tbl_w_meta.tbl.num_rows() + if n_rows <= 0: + break tables, colnames = zip( *( (piece.tbl, piece.column_names(include_children=False)) @@ -321,7 +342,8 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: tbl_w_meta = plc.io.parquet.read_parquet( plc.io.SourceInfo(self.paths), columns=with_columns, - nrows=nrows, + nrows=n_rows, + skip_rows=self.skip_rows, ) df = DataFrame.from_table( tbl_w_meta.tbl, @@ -354,12 +376,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: raise NotImplementedError( f"Unhandled scan type: {self.typ}" ) # pragma: no cover; post init trips first - if ( - row_index is not None - # TODO: remove condition when dropping support for polars 1.0 - # https://github.com/pola-rs/polars/pull/17363 - and row_index[0] in self.schema - ): + if row_index is not None: name, offset = row_index dtype = self.schema[name] step = plc.interop.from_arrow( @@ -481,36 +498,6 @@ def evaluate( return DataFrame(columns) -def placeholder_column(n: int) -> plc.Column: - """ - Produce a placeholder pylibcudf column with NO BACKING DATA. - - Parameters - ---------- - n - Number of rows the column will advertise - - Returns - ------- - pylibcudf Column that is almost unusable. DO NOT ACCESS THE DATA BUFFER. - - Notes - ----- - This is used to avoid allocating data for count aggregations. - """ - return plc.Column( - plc.DataType(plc.TypeId.INT8), - n, - plc.gpumemoryview( - types.SimpleNamespace(__cuda_array_interface__={"data": (1, True)}) - ), - None, - 0, - 0, - [], - ) - - @dataclasses.dataclass class GroupBy(IR): """Perform a groupby.""" @@ -557,8 +544,7 @@ def check_agg(agg: expr.Expr) -> int: def __post_init__(self) -> None: """Check whether all the aggregations are implemented.""" - if self.options.rolling is None and self.maintain_order: - raise NotImplementedError("Maintaining order in groupby") + super().__post_init__() if self.options.rolling: raise NotImplementedError( "rolling window/groupby" @@ -566,6 +552,8 @@ def __post_init__(self) -> None: if any(GroupBy.check_agg(a.value) > 1 for a in self.agg_requests): raise NotImplementedError("Nested aggregations in groupby") self.agg_infos = [req.collect_agg(depth=0) for req in self.agg_requests] + if len(self.keys) == 0: + raise NotImplementedError("dynamic groupby") def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" @@ -591,7 +579,10 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: for info in self.agg_infos: for pre_eval, req, rep in info.requests: if pre_eval is None: - col = placeholder_column(df.num_rows) + # A count aggregation, doesn't touch the column, + # but we need to have one. Rather than evaluating + # one, just use one of the key columns. + col = keys[0].obj else: col = pre_eval.evaluate(df).obj requests.append(plc.groupby.GroupByRequest(col, [req])) @@ -611,7 +602,34 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: results = [ req.evaluate(result_subs, mapping=mapping) for req in self.agg_requests ] - return DataFrame(broadcast(*result_keys, *results)).slice(self.options.slice) + broadcasted = broadcast(*result_keys, *results) + result_keys = broadcasted[: len(result_keys)] + results = broadcasted[len(result_keys) :] + # Handle order preservation of groups + # like cudf classic does + # https://github.com/rapidsai/cudf/blob/5780c4d8fb5afac2e04988a2ff5531f94c22d3a3/python/cudf/cudf/core/groupby/groupby.py#L723-L743 + if self.maintain_order and not sorted: + left = plc.stream_compaction.stable_distinct( + plc.Table([k.obj for k in keys]), + list(range(group_keys.num_columns())), + plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST, + plc.types.NullEquality.EQUAL, + plc.types.NanEquality.ALL_EQUAL, + ) + right = plc.Table([key.obj for key in result_keys]) + _, indices = plc.join.left_join(left, right, plc.types.NullEquality.EQUAL) + ordered_table = plc.copying.gather( + plc.Table([col.obj for col in broadcasted]), + indices, + plc.copying.OutOfBoundsPolicy.DONT_CHECK, + ) + broadcasted = [ + NamedColumn(reordered, b.name) + for reordered, b in zip( + ordered_table.columns(), broadcasted, strict=True + ) + ] + return DataFrame(broadcasted).slice(self.options.slice) @dataclasses.dataclass @@ -627,7 +645,7 @@ class Join(IR): right_on: list[expr.NamedExpr] """List of expressions used as keys in the right frame.""" options: tuple[ - Literal["inner", "left", "full", "leftsemi", "leftanti", "cross"], + Literal["inner", "left", "right", "full", "leftsemi", "leftanti", "cross"], bool, tuple[int, int] | None, str | None, @@ -644,6 +662,7 @@ class Join(IR): def __post_init__(self) -> None: """Validate preconditions.""" + super().__post_init__() if any( isinstance(e.value, expr.Literal) for e in itertools.chain(self.left_on, self.right_on) @@ -653,7 +672,7 @@ def __post_init__(self) -> None: @staticmethod @cache def _joiners( - how: Literal["inner", "left", "full", "leftsemi", "leftanti"], + how: Literal["inner", "left", "right", "full", "leftsemi", "leftanti"], ) -> tuple[ Callable, plc.copying.OutOfBoundsPolicy, plc.copying.OutOfBoundsPolicy | None ]: @@ -663,7 +682,7 @@ def _joiners( plc.copying.OutOfBoundsPolicy.DONT_CHECK, plc.copying.OutOfBoundsPolicy.DONT_CHECK, ) - elif how == "left": + elif how == "left" or how == "right": return ( plc.join.left_join, plc.copying.OutOfBoundsPolicy.DONT_CHECK, @@ -687,8 +706,7 @@ def _joiners( plc.copying.OutOfBoundsPolicy.DONT_CHECK, None, ) - else: - assert_never(how) + assert_never(how) def _reorder_maps( self, @@ -786,8 +804,12 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: table = plc.copying.gather(left.table, lg, left_policy) result = DataFrame.from_table(table, left.column_names) else: + if how == "right": + # Right join is a left join with the tables swapped + left, right = right, left + left_on, right_on = right_on, left_on lg, rg = join_fn(left_on.table, right_on.table, null_equality) - if how == "left": + if how == "left" or how == "right": # Order of left table is preserved lg, rg = self._reorder_maps( left.num_rows, lg, left_policy, right.num_rows, rg, right_policy @@ -815,6 +837,9 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: ) ) right = right.discard_columns(right_on.column_names_set) + if how == "right": + # Undo the swap for right join before gluing together. + left, right = right, left right = right.rename_columns( { name: f"{name}{suffix}" @@ -1065,11 +1090,13 @@ class MapFunction(IR): # "merge_sorted", "rename", "explode", + "unpivot", ] ) def __post_init__(self) -> None: """Validate preconditions.""" + super().__post_init__() if self.name not in MapFunction._NAMES: raise NotImplementedError(f"Unhandled map function {self.name}") if self.name == "explode": @@ -1086,6 +1113,22 @@ def __post_init__(self) -> None: set(new) & (set(self.df.schema.keys() - set(old))) ): raise NotImplementedError("Duplicate new names in rename.") + elif self.name == "unpivot": + indices, pivotees, variable_name, value_name = self.options + value_name = "value" if value_name is None else value_name + variable_name = "variable" if variable_name is None else variable_name + if len(pivotees) == 0: + index = frozenset(indices) + pivotees = [name for name in self.df.schema if name not in index] + if not all( + dtypes.can_cast(self.df.schema[p], self.schema[value_name]) + for p in pivotees + ): + raise NotImplementedError( + "Unpivot cannot cast all input columns to " + f"{self.schema[value_name].id()}" + ) + self.options = (indices, pivotees, variable_name, value_name) def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: """Evaluate and return a dataframe.""" @@ -1107,6 +1150,41 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: return DataFrame.from_table( plc.lists.explode_outer(df.table, index), df.column_names ).sorted_like(df, subset=subset) + elif self.name == "unpivot": + indices, pivotees, variable_name, value_name = self.options + npiv = len(pivotees) + df = self.df.evaluate(cache=cache) + index_columns = [ + NamedColumn(col, name) + for col, name in zip( + plc.reshape.tile(df.select(indices).table, npiv).columns(), + indices, + strict=True, + ) + ] + (variable_column,) = plc.filling.repeat( + plc.Table( + [ + plc.interop.from_arrow( + pa.array( + pivotees, + type=plc.interop.to_arrow(self.schema[variable_name]), + ), + ) + ] + ), + df.num_rows, + ).columns() + value_column = plc.concatenate.concatenate( + [c.astype(self.schema[value_name]) for c in df.select(pivotees).columns] + ) + return DataFrame( + [ + *index_columns, + NamedColumn(variable_column, variable_name), + NamedColumn(value_column, value_name), + ] + ) else: raise AssertionError("Should never be reached") # pragma: no cover @@ -1122,6 +1200,7 @@ class Union(IR): def __post_init__(self) -> None: """Validate preconditions.""" + super().__post_init__() schema = self.dfs[0].schema if not all(s.schema == schema for s in self.dfs[1:]): raise NotImplementedError("Schema mismatch") diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 6dc97c7cb51..45881afe0c8 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -75,13 +75,12 @@ def _translate_ir( def _( node: pl_ir.PythonScan, visitor: NodeTraverser, schema: dict[str, plc.DataType] ) -> ir.IR: - return ir.PythonScan( - schema, - node.options, - translate_named_expr(visitor, n=node.predicate) - if node.predicate is not None - else None, + scan_fn, with_columns, source_type, predicate, nrows = node.options + options = (scan_fn, with_columns, source_type, nrows) + predicate = ( + translate_named_expr(visitor, n=predicate) if predicate is not None else None ) + return ir.PythonScan(schema, options, predicate) @_translate_ir.register @@ -94,13 +93,35 @@ def _( cloud_options = None else: reader_options, cloud_options = map(json.loads, options) + if ( + typ == "csv" + and visitor.version()[0] == 1 + and reader_options["schema"] is not None + ): + reader_options["schema"] = { + "fields": reader_options["schema"]["inner"] + } # pragma: no cover; CI tests 1.7 + file_options = node.file_options + with_columns = file_options.with_columns + n_rows = file_options.n_rows + if n_rows is None: + n_rows = -1 # All rows + skip_rows = 0 # Don't skip + else: + # TODO: with versioning, rename on the rust side + skip_rows, n_rows = n_rows + + row_index = file_options.row_index return ir.Scan( schema, typ, reader_options, cloud_options, node.paths, - node.file_options, + with_columns, + skip_rows, + n_rows, + row_index, translate_named_expr(visitor, n=node.predicate) if node.predicate is not None else None, @@ -293,10 +314,28 @@ def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR: ctx: AbstractContextManager[None] = ( set_node(visitor, n) if n is not None else noop_context ) + # IR is versioned with major.minor, minor is bumped for backwards + # compatible changes (e.g. adding new nodes), major is bumped for + # incompatible changes (e.g. renaming nodes). + # Polars 1.7 changes definition of the CSV reader options schema name. + if (version := visitor.version()) >= (3, 0): + raise NotImplementedError( + f"No support for polars IR {version=}" + ) # pragma: no cover; no such version for now. + with ctx: + polars_schema = visitor.get_schema() node = visitor.view_current_node() - schema = {k: dtypes.from_polars(v) for k, v in visitor.get_schema().items()} - return _translate_ir(node, visitor, schema) + schema = {k: dtypes.from_polars(v) for k, v in polars_schema.items()} + result = _translate_ir(node, visitor, schema) + if any( + isinstance(dtype, pl.Null) + for dtype in pl.datatypes.unpack_dtypes(*polars_schema.values()) + ): + raise NotImplementedError( + f"No GPU support for {result} with Null column dtype." + ) + return result def translate_named_expr( @@ -345,6 +384,24 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex name, *options = node.function_data options = tuple(options) if isinstance(name, pl_expr.StringFunction): + if name in { + pl_expr.StringFunction.StripChars, + pl_expr.StringFunction.StripCharsStart, + pl_expr.StringFunction.StripCharsEnd, + }: + column, chars = (translate_expr(visitor, n=n) for n in node.input) + if isinstance(chars, expr.Literal): + if chars.value == pa.scalar(""): + # No-op in polars, but libcudf uses empty string + # as signifier to remove whitespace. + return column + elif chars.value == pa.scalar(None): + # Polars uses None to mean "strip all whitespace" + chars = expr.Literal( + column.dtype, + pa.scalar("", type=plc.interop.to_arrow(column.dtype)), + ) + return expr.StringFunction(dtype, name, options, column, chars) return expr.StringFunction( dtype, name, @@ -369,19 +426,43 @@ def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> ex *(translate_expr(visitor, n=n) for n in node.input), ) elif isinstance(name, pl_expr.TemporalFunction): - return expr.TemporalFunction( + # functions for which evaluation of the expression may not return + # the same dtype as polars, either due to libcudf returning a different + # dtype, or due to our internal processing affecting what libcudf returns + needs_cast = { + pl_expr.TemporalFunction.Year, + pl_expr.TemporalFunction.Month, + pl_expr.TemporalFunction.Day, + pl_expr.TemporalFunction.WeekDay, + pl_expr.TemporalFunction.Hour, + pl_expr.TemporalFunction.Minute, + pl_expr.TemporalFunction.Second, + pl_expr.TemporalFunction.Millisecond, + } + result_expr = expr.TemporalFunction( dtype, name, options, *(translate_expr(visitor, n=n) for n in node.input), ) + if name in needs_cast: + return expr.Cast(dtype, result_expr) + return result_expr + elif isinstance(name, str): - return expr.UnaryFunction( - dtype, - name, - options, - *(translate_expr(visitor, n=n) for n in node.input), - ) + children = (translate_expr(visitor, n=n) for n in node.input) + if name == "log": + (base,) = options + (child,) = children + return expr.BinOp( + dtype, + plc.binaryop.BinaryOperator.LOG_BASE, + child, + expr.Literal(dtype, pa.scalar(base, type=plc.interop.to_arrow(dtype))), + ) + elif name == "pow": + return expr.BinOp(dtype, plc.binaryop.BinaryOperator.POW, *children) + return expr.UnaryFunction(dtype, name, options, *children) raise NotImplementedError( f"No handler for Expr function node with {name=}" ) # pragma: no cover; polars raises on the rust side for now diff --git a/python/cudf_polars/cudf_polars/testing/asserts.py b/python/cudf_polars/cudf_polars/testing/asserts.py index d37c96a15de..a79d45899cd 100644 --- a/python/cudf_polars/cudf_polars/testing/asserts.py +++ b/python/cudf_polars/cudf_polars/testing/asserts.py @@ -5,12 +5,11 @@ from __future__ import annotations -from functools import partial from typing import TYPE_CHECKING +from polars import GPUEngine from polars.testing.asserts import assert_frame_equal -from cudf_polars.callback import execute_with_cudf from cudf_polars.dsl.translate import translate_ir if TYPE_CHECKING: @@ -77,21 +76,13 @@ def assert_gpu_result_equal( NotImplementedError If GPU collection failed in some way. """ - if collect_kwargs is None: - collect_kwargs = {} - final_polars_collect_kwargs = collect_kwargs.copy() - final_cudf_collect_kwargs = collect_kwargs.copy() - if polars_collect_kwargs is not None: - final_polars_collect_kwargs.update(polars_collect_kwargs) - if cudf_collect_kwargs is not None: # pragma: no cover - # exclude from coverage since not used ATM - # but this is probably still useful - final_cudf_collect_kwargs.update(cudf_collect_kwargs) - expect = lazydf.collect(**final_polars_collect_kwargs) - got = lazydf.collect( - **final_cudf_collect_kwargs, - post_opt_callback=partial(execute_with_cudf, raise_on_fail=True), + final_polars_collect_kwargs, final_cudf_collect_kwargs = _process_kwargs( + collect_kwargs, polars_collect_kwargs, cudf_collect_kwargs ) + + expect = lazydf.collect(**final_polars_collect_kwargs) + engine = GPUEngine(raise_on_fail=True) + got = lazydf.collect(**final_cudf_collect_kwargs, engine=engine) assert_frame_equal( expect, got, @@ -134,3 +125,94 @@ def assert_ir_translation_raises(q: pl.LazyFrame, *exceptions: type[Exception]) raise AssertionError(f"Translation DID NOT RAISE {exceptions}") from e else: raise AssertionError(f"Translation DID NOT RAISE {exceptions}") + + +def _process_kwargs( + collect_kwargs: dict[OptimizationArgs, bool] | None, + polars_collect_kwargs: dict[OptimizationArgs, bool] | None, + cudf_collect_kwargs: dict[OptimizationArgs, bool] | None, +) -> tuple[dict[OptimizationArgs, bool], dict[OptimizationArgs, bool]]: + if collect_kwargs is None: + collect_kwargs = {} + final_polars_collect_kwargs = collect_kwargs.copy() + final_cudf_collect_kwargs = collect_kwargs.copy() + if polars_collect_kwargs is not None: # pragma: no cover; not currently used + final_polars_collect_kwargs.update(polars_collect_kwargs) + if cudf_collect_kwargs is not None: # pragma: no cover; not currently used + final_cudf_collect_kwargs.update(cudf_collect_kwargs) + return final_polars_collect_kwargs, final_cudf_collect_kwargs + + +def assert_collect_raises( + lazydf: pl.LazyFrame, + *, + polars_except: type[Exception] | tuple[type[Exception], ...], + cudf_except: type[Exception] | tuple[type[Exception], ...], + collect_kwargs: dict[OptimizationArgs, bool] | None = None, + polars_collect_kwargs: dict[OptimizationArgs, bool] | None = None, + cudf_collect_kwargs: dict[OptimizationArgs, bool] | None = None, +): + """ + Assert that collecting the result of a query raises the expected exceptions. + + Parameters + ---------- + lazydf + frame to collect. + collect_kwargs + Common keyword arguments to pass to collect for both polars CPU and + cudf-polars. + Useful for controlling optimization settings. + polars_except + Exception or exceptions polars CPU is expected to raise. + cudf_except + Exception or exceptions polars GPU is expected to raise. + collect_kwargs + Common keyword arguments to pass to collect for both polars CPU and + cudf-polars. + Useful for controlling optimization settings. + polars_collect_kwargs + Keyword arguments to pass to collect for execution on polars CPU. + Overrides kwargs in collect_kwargs. + Useful for controlling optimization settings. + cudf_collect_kwargs + Keyword arguments to pass to collect for execution on cudf-polars. + Overrides kwargs in collect_kwargs. + Useful for controlling optimization settings. + + Returns + ------- + None + If both sides raise the expected exceptions. + + Raises + ------ + AssertionError + If either side did not raise the expected exceptions. + """ + final_polars_collect_kwargs, final_cudf_collect_kwargs = _process_kwargs( + collect_kwargs, polars_collect_kwargs, cudf_collect_kwargs + ) + + try: + lazydf.collect(**final_polars_collect_kwargs) + except polars_except: + pass + except Exception as e: + raise AssertionError( + f"CPU execution RAISED {type(e)}, EXPECTED {polars_except}" + ) from e + else: + raise AssertionError(f"CPU execution DID NOT RAISE {polars_except}") + + engine = GPUEngine(raise_on_fail=True) + try: + lazydf.collect(**final_cudf_collect_kwargs, engine=engine) + except cudf_except: + pass + except Exception as e: + raise AssertionError( + f"GPU execution RAISED {type(e)}, EXPECTED {polars_except}" + ) from e + else: + raise AssertionError(f"GPU execution DID NOT RAISE {polars_except}") diff --git a/python/cudf_polars/cudf_polars/testing/plugin.py b/python/cudf_polars/cudf_polars/testing/plugin.py new file mode 100644 index 00000000000..c40d59e6d33 --- /dev/null +++ b/python/cudf_polars/cudf_polars/testing/plugin.py @@ -0,0 +1,154 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 + +"""Plugin for running polars test suite setting GPU engine as default.""" + +from __future__ import annotations + +from functools import partialmethod +from typing import TYPE_CHECKING + +import pytest + +import polars + +if TYPE_CHECKING: + from collections.abc import Mapping + + +def pytest_addoption(parser: pytest.Parser): + """Add plugin-specific options.""" + group = parser.getgroup( + "cudf-polars", "Plugin to set GPU as default engine for polars tests" + ) + group.addoption( + "--cudf-polars-no-fallback", + action="store_true", + help="Turn off fallback to CPU when running tests (default use fallback)", + ) + + +def pytest_configure(config: pytest.Config): + """Enable use of this module as a pytest plugin to enable GPU collection.""" + no_fallback = config.getoption("--cudf-polars-no-fallback") + collect = polars.LazyFrame.collect + engine = polars.GPUEngine(raise_on_fail=no_fallback) + polars.LazyFrame.collect = partialmethod(collect, engine=engine) + config.addinivalue_line( + "filterwarnings", + "ignore:.*GPU engine does not support streaming or background collection", + ) + config.addinivalue_line( + "filterwarnings", + "ignore:.*Query execution with GPU not supported", + ) + + +EXPECTED_FAILURES: Mapping[str, str] = { + "tests/unit/io/test_csv.py::test_compressed_csv": "Need to determine if file is compressed", + "tests/unit/io/test_csv.py::test_read_csv_only_loads_selected_columns": "Memory usage won't be correct due to GPU", + "tests/unit/io/test_lazy_count_star.py::test_count_compressed_csv_18057": "Need to determine if file is compressed", + "tests/unit/io/test_lazy_csv.py::test_scan_csv_slice_offset_zero": "Integer overflow in sliced read", + "tests/unit/io/test_lazy_parquet.py::test_parquet_is_in_statistics": "Debug output on stderr doesn't match", + "tests/unit/io/test_lazy_parquet.py::test_parquet_statistics": "Debug output on stderr doesn't match", + "tests/unit/io/test_lazy_parquet.py::test_parquet_different_schema[False]": "Needs cudf#16394", + "tests/unit/io/test_lazy_parquet.py::test_parquet_schema_mismatch_panic_17067[False]": "Needs cudf#16394", + "tests/unit/io/test_lazy_parquet.py::test_parquet_slice_pushdown_non_zero_offset[False]": "Thrift data not handled correctly/slice pushdown wrong?", + "tests/unit/io/test_parquet.py::test_read_parquet_only_loads_selected_columns_15098": "Memory usage won't be correct due to GPU", + "tests/unit/io/test_scan.py::test_scan[single-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_limit[single-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_filter[single-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_filter_and_limit[single-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_limit_and_filter[single-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_and_limit[single-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_and_filter[single-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_limit_and_filter[single-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan[glob-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_limit[glob-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_filter[glob-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_filter_and_limit[glob-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_limit_and_filter[glob-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_and_limit[glob-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_and_filter[glob-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_limit_and_filter[glob-csv-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_limit[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_filter[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_filter_and_limit[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_limit_and_filter[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_and_limit[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_and_filter[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_limit_and_filter[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_projected_out[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_filter_and_limit[glob-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_limit[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_filter[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_filter_and_limit[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_limit_and_filter[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_and_limit[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_and_filter[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_limit_and_filter[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_projected_out[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_with_row_index_filter_and_limit[single-parquet-async]": "Debug output on stderr doesn't match", + "tests/unit/io/test_scan.py::test_scan_include_file_name[False-scan_parquet-write_parquet]": "Need to add include_file_path to IR", + "tests/unit/io/test_scan.py::test_scan_include_file_name[False-scan_csv-write_csv]": "Need to add include_file_path to IR", + "tests/unit/io/test_scan.py::test_scan_include_file_name[False-scan_ndjson-write_ndjson]": "Need to add include_file_path to IR", + "tests/unit/lazyframe/test_engine_selection.py::test_engine_import_error_raises[gpu]": "Expect this to pass because cudf-polars is installed", + "tests/unit/lazyframe/test_engine_selection.py::test_engine_import_error_raises[engine1]": "Expect this to pass because cudf-polars is installed", + "tests/unit/lazyframe/test_lazyframe.py::test_round[dtype1-123.55-1-123.6]": "Rounding midpoints is handled incorrectly", + "tests/unit/lazyframe/test_lazyframe.py::test_cast_frame": "Casting that raises not supported on GPU", + "tests/unit/lazyframe/test_lazyframe.py::test_lazy_cache_hit": "Debug output on stderr doesn't match", + "tests/unit/operations/aggregation/test_aggregations.py::test_duration_function_literal": "Broadcasting inside groupby-agg not supported", + "tests/unit/operations/aggregation/test_aggregations.py::test_sum_empty_and_null_set": "libcudf sums column of all nulls to null, not zero", + "tests/unit/operations/aggregation/test_aggregations.py::test_binary_op_agg_context_no_simplify_expr_12423": "groupby-agg of just literals should not produce collect_list", + "tests/unit/operations/aggregation/test_aggregations.py::test_nan_inf_aggregation": "treatment of nans and nulls together is different in libcudf and polars in groupby-agg context", + "tests/unit/operations/test_abs.py::test_abs_duration": "Need to raise for unsupported uops on timelike values", + "tests/unit/operations/test_group_by.py::test_group_by_mean_by_dtype[input7-expected7-Float32-Float32]": "Mismatching dtypes, needs cudf#15852", + "tests/unit/operations/test_group_by.py::test_group_by_mean_by_dtype[input10-expected10-Date-output_dtype10]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_mean_by_dtype[input11-expected11-input_dtype11-output_dtype11]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_mean_by_dtype[input12-expected12-input_dtype12-output_dtype12]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_mean_by_dtype[input13-expected13-input_dtype13-output_dtype13]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_median_by_dtype[input7-expected7-Float32-Float32]": "Mismatching dtypes, needs cudf#15852", + "tests/unit/operations/test_group_by.py::test_group_by_median_by_dtype[input10-expected10-Date-output_dtype10]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_median_by_dtype[input11-expected11-input_dtype11-output_dtype11]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_median_by_dtype[input12-expected12-input_dtype12-output_dtype12]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_median_by_dtype[input13-expected13-input_dtype13-output_dtype13]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_median_by_dtype[input14-expected14-input_dtype14-output_dtype14]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_median_by_dtype[input15-expected15-input_dtype15-output_dtype15]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_median_by_dtype[input16-expected16-input_dtype16-output_dtype16]": "Unsupported groupby-agg for a particular dtype", + "tests/unit/operations/test_group_by.py::test_group_by_binary_agg_with_literal": "Incorrect broadcasting of literals in groupby-agg", + "tests/unit/operations/test_group_by.py::test_aggregated_scalar_elementwise_15602": "Unsupported boolean function/dtype combination in groupby-agg", + "tests/unit/operations/test_group_by.py::test_schemas[data1-expr1-expected_select1-expected_gb1]": "Mismatching dtypes, needs cudf#15852", + "tests/unit/operations/test_group_by_dynamic.py::test_group_by_dynamic_by_monday_and_offset_5444": "IR needs to expose groupby-dynamic information", + "tests/unit/operations/test_group_by_dynamic.py::test_group_by_dynamic_label[left-expected0]": "IR needs to expose groupby-dynamic information", + "tests/unit/operations/test_group_by_dynamic.py::test_group_by_dynamic_label[right-expected1]": "IR needs to expose groupby-dynamic information", + "tests/unit/operations/test_group_by_dynamic.py::test_group_by_dynamic_label[datapoint-expected2]": "IR needs to expose groupby-dynamic information", + "tests/unit/operations/test_group_by_dynamic.py::test_rolling_dynamic_sortedness_check": "IR needs to expose groupby-dynamic information", + "tests/unit/operations/test_group_by_dynamic.py::test_group_by_dynamic_validation": "IR needs to expose groupby-dynamic information", + "tests/unit/operations/test_group_by_dynamic.py::test_group_by_dynamic_15225": "IR needs to expose groupby-dynamic information", + "tests/unit/operations/test_join.py::test_cross_join_slice_pushdown": "Need to implement slice pushdown for cross joins", + "tests/unit/sql/test_cast.py::test_cast_errors[values0-values::uint8-conversion from `f64` to `u64` failed]": "Casting that raises not supported on GPU", + "tests/unit/sql/test_cast.py::test_cast_errors[values1-values::uint4-conversion from `i64` to `u32` failed]": "Casting that raises not supported on GPU", + "tests/unit/sql/test_cast.py::test_cast_errors[values2-values::int1-conversion from `i64` to `i8` failed]": "Casting that raises not supported on GPU", + "tests/unit/sql/test_miscellaneous.py::test_read_csv": "Incorrect handling of missing_is_null in read_csv", + "tests/unit/sql/test_wildcard_opts.py::test_select_wildcard_errors": "Raises correctly but with different exception", + "tests/unit/streaming/test_streaming_io.py::test_parquet_eq_statistics": "Debug output on stderr doesn't match", + "tests/unit/test_cse.py::test_cse_predicate_self_join": "Debug output on stderr doesn't match", + "tests/unit/test_empty.py::test_empty_9137": "Mismatching dtypes, needs cudf#15852", + # Maybe flaky, order-dependent? + "tests/unit/test_projections.py::test_schema_full_outer_join_projection_pd_13287": "Order-specific result check, query is correct but in different order", + "tests/unit/test_queries.py::test_group_by_agg_equals_zero_3535": "libcudf sums all nulls to null, not zero", +} + + +def pytest_collection_modifyitems( + session: pytest.Session, config: pytest.Config, items: list[pytest.Item] +): + """Mark known failing tests.""" + if config.getoption("--cudf-polars-no-fallback"): + # Don't xfail tests if running without fallback + return + for item in items: + if item.nodeid in EXPECTED_FAILURES: + item.add_marker(pytest.mark.xfail(reason=EXPECTED_FAILURES[item.nodeid])) diff --git a/python/cudf_polars/cudf_polars/typing/__init__.py b/python/cudf_polars/cudf_polars/typing/__init__.py index adab10bdded..240b11bdf59 100644 --- a/python/cudf_polars/cudf_polars/typing/__init__.py +++ b/python/cudf_polars/cudf_polars/typing/__init__.py @@ -84,6 +84,10 @@ def view_expression(self, n: int) -> Expr: """Convert the given expression to python rep.""" ... + def version(self) -> tuple[int, int]: + """The IR version as `(major, minor)`.""" + ... + def set_udf( self, callback: Callable[[list[str] | None, str | None, int | None], pl.DataFrame], diff --git a/python/cudf_polars/cudf_polars/utils/dtypes.py b/python/cudf_polars/cudf_polars/utils/dtypes.py index 7f6ea1edfd9..4154a404e98 100644 --- a/python/cudf_polars/cudf_polars/utils/dtypes.py +++ b/python/cudf_polars/cudf_polars/utils/dtypes.py @@ -13,7 +13,7 @@ import polars as pl -__all__ = ["from_polars", "downcast_arrow_lists"] +__all__ = ["from_polars", "downcast_arrow_lists", "can_cast"] def downcast_arrow_lists(typ: pa.DataType) -> pa.DataType: @@ -45,6 +45,28 @@ def downcast_arrow_lists(typ: pa.DataType) -> pa.DataType: return typ +def can_cast(from_: plc.DataType, to: plc.DataType) -> bool: + """ + Can we cast (via :func:`~.pylibcudf.unary.cast`) between two datatypes. + + Parameters + ---------- + from_ + Source datatype + to + Target datatype + + Returns + ------- + True if casting is supported, False otherwise + """ + return ( + plc.traits.is_fixed_width(to) + and plc.traits.is_fixed_width(from_) + and plc.unary.is_supported_cast(from_, to) + ) + + @cache def from_polars(dtype: pl.DataType) -> plc.DataType: """ diff --git a/python/cudf_polars/cudf_polars/utils/versions.py b/python/cudf_polars/cudf_polars/utils/versions.py index 9807cffb384..2e6efde968c 100644 --- a/python/cudf_polars/cudf_polars/utils/versions.py +++ b/python/cudf_polars/cudf_polars/utils/versions.py @@ -12,18 +12,11 @@ POLARS_VERSION = parse(__version__) -POLARS_VERSION_GE_10 = POLARS_VERSION >= parse("1.0") -POLARS_VERSION_GE_11 = POLARS_VERSION >= parse("1.1") -POLARS_VERSION_GE_12 = POLARS_VERSION >= parse("1.2") -POLARS_VERSION_GE_121 = POLARS_VERSION >= parse("1.2.1") -POLARS_VERSION_GT_10 = POLARS_VERSION > parse("1.0") -POLARS_VERSION_GT_11 = POLARS_VERSION > parse("1.1") -POLARS_VERSION_GT_12 = POLARS_VERSION > parse("1.2") - -POLARS_VERSION_LE_12 = POLARS_VERSION <= parse("1.2") -POLARS_VERSION_LE_11 = POLARS_VERSION <= parse("1.1") -POLARS_VERSION_LT_12 = POLARS_VERSION < parse("1.2") -POLARS_VERSION_LT_11 = POLARS_VERSION < parse("1.1") - -if POLARS_VERSION < parse("1.0"): # pragma: no cover - raise ImportError("cudf_polars requires py-polars v1.0 or greater.") +POLARS_VERSION_GE_16 = POLARS_VERSION >= parse("1.6") +POLARS_VERSION_GT_16 = POLARS_VERSION > parse("1.6") +POLARS_VERSION_LT_16 = POLARS_VERSION < parse("1.6") + +if POLARS_VERSION_LT_16: + raise ImportError( + "cudf_polars requires py-polars v1.6 or greater." + ) # pragma: no cover diff --git a/python/cudf_polars/docs/overview.md b/python/cudf_polars/docs/overview.md index 6cd36136bf8..103ac1a674e 100644 --- a/python/cudf_polars/docs/overview.md +++ b/python/cudf_polars/docs/overview.md @@ -15,8 +15,10 @@ You will need: ## Installing polars -We will need to build polars from source. Until things settle down, -live at `HEAD`. +`cudf-polars` works with polars >= 1.3, as long as the internal IR +version doesn't get a major version bump. So `pip install polars>=1.3` +should work. For development, if we're adding things to the polars +side of things, we will need to build polars from source: ```sh git clone https://github.com/pola-rs/polars @@ -59,7 +61,7 @@ The executor for the polars logical plan lives in the cudf repo, in ```sh cd cudf/python/cudf_polars -uv pip install --no-build-isolation --no-deps -e . +pip install --no-build-isolation --no-deps -e . ``` You should now be able to run the tests in the `cudf_polars` package: @@ -69,16 +71,18 @@ pytest -v tests # Executor design -The polars `LazyFrame.collect` functionality offers a -"post-optimization" callback that may be used by a third party library -to replace a node (or more, though we only replace a single node) in the -optimized logical plan with a Python callback that is to deliver the -result of evaluating the plan. This splits the execution of the plan -into two phases. First, a symbolic phase which translates to our -internal representation (IR). Second, an execution phase which executes -using our IR. - -The translation phase receives the a low-level Rust `NodeTraverse` +The polars `LazyFrame.collect` functionality offers configuration of +the engine to use for collection through the `engine` argument. At a +low level, this provides for configuration of a "post-optimization" +callback that may be used by a third party library to replace a node +(or more, though we only replace a single node) in the optimized +logical plan with a Python callback that is to deliver the result of +evaluating the plan. This splits the execution of the plan into two +phases. First, a symbolic phase which translates to our internal +representation (IR). Second, an execution phase which executes using +our IR. + +The translation phase receives the a low-level Rust `NodeTraverser` object which delivers Python representations of the plan nodes (and expressions) one at a time. During translation, we endeavour to raise `NotImplementedError` for any unsupported functionality. This way, if @@ -86,33 +90,60 @@ we can't execute something, we just don't modify the logical plan at all: if we can translate the IR, it is assumed that evaluation will later succeed. -The usage of the cudf-based executor is therefore, at present: +The usage of the cudf-based executor is therefore selected with the +gpu engine: ```python -from cudf_polars.callback import execute_with_cudf +import polars as pl -result = q.collect(post_opt_callback=execute_with_cudf) +result = q.collect(engine="gpu") ``` This should either transparently run on the GPU and deliver a polars dataframe, or else fail (but be handled) and just run the normal CPU -execution. +execution. If `POLARS_VERBOSE` is true, then fallback is logged with a +`PerformanceWarning`. -If you want to fail during translation, set the keyword argument -`raise_on_fail` to `True`: +As well as a string argument, the engine can also be specified with a +polars `GPUEngine` object. This allows passing more configuration in. +Currently, the public properties are `device`, to select the device, +and `memory_resource`, to select the RMM memory resource used for +allocations during the collection phase. +For example: ```python -from functools import partial -from cudf_polars.callback import execute_with_cudf +import polars as pl -result = q.collect( - post_opt_callback=partial(execute_with_cudf, raise_on_fail=True) -) +result = q.collect(engine=pl.GPUEngine(device=1, memory_resource=mr)) +``` + +Uses device-1, and the given memory resource. Note that the memory +resource provided _must_ be valid for allocations on the specified +device, no checking is performed. + +For debugging purposes, we can also pass undocumented keyword +arguments, at the moment, `raise_on_fail` is also supported, which +raises, rather than falling back, during translation: + +```python + +result = q.collect(engine=pl.GPUEngine(raise_on_fail=True)) ``` This is mostly useful when writing tests, since in that case we want any failures to propagate, rather than falling back to the CPU mode. +## IR versioning + +On the polars side, the `NodeTraverser` object advertises an internal +version (via `NodeTraverser.version()` as a `(major, minor)` tuple). +`minor` version bumps are for backwards compatible changes (e.g. +exposing new nodes), whereas `major` bumps are for incompatible +changes. We can therefore attempt to detect the IR version +(independently of the polars version) and dispatch, or error +appropriately. This should be done during IR translation in +`translate.py`. + ## Adding a handler for a new plan node Plan node definitions live in `cudf_polars/dsl/ir.py`, these are @@ -175,7 +206,7 @@ around their pylibcudf counterparts. We have four (in 1. `Scalar` (a wrapper around a pylibcudf `Scalar`) 2. `Column` (a wrapper around a pylibcudf `Column`) -3. `NamedColumn` a `Column` with an additional name +3. `NamedColumn` (a `Column` with an additional name) 4. `DataFrame` (a wrapper around a pylibcudf `Table`) The interfaces offered by these are somewhat in flux, but broadly diff --git a/python/cudf_polars/pyproject.toml b/python/cudf_polars/pyproject.toml index 984b5487b98..857a8c14b2f 100644 --- a/python/cudf_polars/pyproject.toml +++ b/python/cudf_polars/pyproject.toml @@ -19,7 +19,7 @@ authors = [ license = { text = "Apache 2.0" } requires-python = ">=3.10" dependencies = [ - "polars>=1.0,<1.3", + "polars>=1.6", "pylibcudf==24.10.*,>=0.0.0a0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. classifiers = [ @@ -58,6 +58,9 @@ exclude_also = [ "class .*\\bProtocol\\):", "assert_never\\(" ] +# The cudf_polars test suite doesn't exercise the plugin, so we omit +# it from coverage checks. +omit = ["cudf_polars/testing/plugin.py"] [tool.ruff] line-length = 88 diff --git a/python/cudf_polars/tests/containers/test_dataframe.py b/python/cudf_polars/tests/containers/test_dataframe.py index 6b470268084..39fb44d55a5 100644 --- a/python/cudf_polars/tests/containers/test_dataframe.py +++ b/python/cudf_polars/tests/containers/test_dataframe.py @@ -9,6 +9,7 @@ import polars as pl from cudf_polars.containers import DataFrame, NamedColumn +from cudf_polars.testing.asserts import assert_gpu_result_equal def test_select_missing_raises(): @@ -140,3 +141,13 @@ def test_sorted_flags_preserved(with_nulls, nulls_last): assert b.null_order == b_null_order assert c.is_sorted == plc.types.Sorted.NO assert df.flags == gf.to_polars().flags + + +def test_empty_name_roundtrips_overlap(): + df = pl.LazyFrame({"": [1, 2, 3], "column_0": [4, 5, 6]}) + assert_gpu_result_equal(df) + + +def test_empty_name_roundtrips_no_overlap(): + df = pl.LazyFrame({"": [1, 2, 3], "b": [4, 5, 6]}) + assert_gpu_result_equal(df) diff --git a/python/cudf_polars/tests/expressions/test_agg.py b/python/cudf_polars/tests/expressions/test_agg.py index 245bde3acab..56055f4c6c2 100644 --- a/python/cudf_polars/tests/expressions/test_agg.py +++ b/python/cudf_polars/tests/expressions/test_agg.py @@ -7,15 +7,38 @@ import polars as pl from cudf_polars.dsl import expr -from cudf_polars.testing.asserts import assert_gpu_result_equal +from cudf_polars.testing.asserts import ( + assert_gpu_result_equal, + assert_ir_translation_raises, +) -@pytest.fixture(params=sorted(expr.Agg._SUPPORTED)) +@pytest.fixture( + params=[ + # regular aggs from Agg + "min", + "max", + "median", + "n_unique", + "first", + "last", + "mean", + "sum", + "count", + "std", + "var", + # scan aggs from UnaryFunction + "cum_min", + "cum_max", + "cum_prod", + "cum_sum", + ] +) def agg(request): return request.param -@pytest.fixture(params=[pl.Int32, pl.Float32, pl.Int16]) +@pytest.fixture(params=[pl.Int32, pl.Float32, pl.Int16, pl.Int8, pl.UInt16]) def dtype(request): return request.param @@ -34,6 +57,11 @@ def df(dtype, with_nulls, is_sorted): if is_sorted: values = sorted(values, key=lambda x: -1000 if x is None else x) + if dtype.is_unsigned_integer(): + values = pl.Series(values).abs() + if is_sorted: + values = values.sort() + df = pl.LazyFrame({"a": values}, schema={"a": dtype}) if is_sorted: return df.set_sorted("a") @@ -52,6 +80,51 @@ def test_agg(df, agg): assert_gpu_result_equal(q, check_dtypes=check_dtypes, check_exact=False) +def test_bool_agg(agg, request): + if agg == "cum_min" or agg == "cum_max": + pytest.skip("Does not apply") + request.applymarker( + pytest.mark.xfail( + condition=agg == "n_unique", + reason="Wrong dtype we get Int32, polars gets UInt32", + ) + ) + df = pl.LazyFrame({"a": [True, False, None, True]}) + expr = getattr(pl.col("a"), agg)() + q = df.select(expr) + + assert_gpu_result_equal(q) + + +@pytest.mark.parametrize("cum_agg", expr.UnaryFunction._supported_cum_aggs) +def test_cum_agg_reverse_unsupported(cum_agg): + df = pl.LazyFrame({"a": [1, 2, 3]}) + expr = getattr(pl.col("a"), cum_agg)(reverse=True) + q = df.select(expr) + + assert_ir_translation_raises(q, NotImplementedError) + + +@pytest.mark.parametrize("q", [0.5, pl.lit(0.5)]) +@pytest.mark.parametrize("interp", ["nearest", "higher", "lower", "midpoint", "linear"]) +def test_quantile(df, q, interp): + expr = pl.col("a").quantile(q, interp) + q = df.select(expr) + + # https://github.com/rapidsai/cudf/issues/15852 + check_dtypes = q.collect_schema()["a"] == pl.Float64 + if not check_dtypes: + with pytest.raises(AssertionError): + assert_gpu_result_equal(q) + assert_gpu_result_equal(q, check_dtypes=check_dtypes, check_exact=False) + + +def test_quantile_invalid_q(df): + expr = pl.col("a").quantile(pl.col("a")) + q = df.select(expr) + assert_ir_translation_raises(q, NotImplementedError) + + @pytest.mark.parametrize( "op", [pl.Expr.min, pl.Expr.nan_min, pl.Expr.max, pl.Expr.nan_max] ) diff --git a/python/cudf_polars/tests/expressions/test_booleanfunction.py b/python/cudf_polars/tests/expressions/test_booleanfunction.py index 97421008669..2347021c40e 100644 --- a/python/cudf_polars/tests/expressions/test_booleanfunction.py +++ b/python/cudf_polars/tests/expressions/test_booleanfunction.py @@ -17,15 +17,11 @@ def has_nulls(request): return request.param -@pytest.mark.parametrize( - "ignore_nulls", - [ - pytest.param( - False, marks=pytest.mark.xfail(reason="No support for Kleene logic") - ), - True, - ], -) +@pytest.fixture(params=[False, True], ids=["include_nulls", "ignore_nulls"]) +def ignore_nulls(request): + return request.param + + def test_booleanfunction_reduction(ignore_nulls): ldf = pl.LazyFrame( { @@ -43,6 +39,25 @@ def test_booleanfunction_reduction(ignore_nulls): assert_gpu_result_equal(query) +@pytest.mark.parametrize("expr", [pl.Expr.any, pl.Expr.all]) +def test_booleanfunction_all_any_kleene(expr, ignore_nulls): + ldf = pl.LazyFrame( + { + "a": [False, None], + "b": [False, False], + "c": [False, True], + "d": [None, False], + "e": pl.Series([None, None], dtype=pl.Boolean()), + "f": [None, True], + "g": [True, False], + "h": [True, None], + "i": [True, True], + } + ) + q = ldf.select(expr(pl.col("*"), ignore_nulls=ignore_nulls)) + assert_gpu_result_equal(q) + + @pytest.mark.parametrize( "expr", [ @@ -54,14 +69,7 @@ def test_booleanfunction_reduction(ignore_nulls): ids=lambda f: f"{f.__name__}()", ) @pytest.mark.parametrize("has_nans", [False, True], ids=["no_nans", "nans"]) -def test_boolean_function_unary(request, expr, has_nans, has_nulls): - if has_nulls and expr in (pl.Expr.is_nan, pl.Expr.is_not_nan): - request.applymarker( - pytest.mark.xfail( - reason="Need to copy null mask since is_{not_}nan(null) => null" - ) - ) - +def test_boolean_function_unary(expr, has_nans, has_nulls): values: list[float | None] = [1, 2, 3, 4, 5] if has_nans: values[3] = float("nan") @@ -119,9 +127,7 @@ def test_boolean_isbetween(closed, bounds): "expr", [pl.any_horizontal("*"), pl.all_horizontal("*")], ids=["any", "all"] ) @pytest.mark.parametrize("wide", [False, True], ids=["narrow", "wide"]) -def test_boolean_horizontal(request, expr, has_nulls, wide): - if has_nulls: - request.applymarker(pytest.mark.xfail(reason="No support for Kleene logic")) +def test_boolean_horizontal(expr, has_nulls, wide): ldf = pl.LazyFrame( { "a": [False, False, False, False, False, True], @@ -164,6 +170,18 @@ def test_boolean_is_in(expr): assert_gpu_result_equal(q) +@pytest.mark.parametrize("expr", [pl.Expr.and_, pl.Expr.or_, pl.Expr.xor]) +def test_boolean_kleene_logic(expr): + ldf = pl.LazyFrame( + { + "a": [False, False, False, None, None, None, True, True, True], + "b": [False, None, True, False, None, True, False, None, True], + } + ) + q = ldf.select(expr(pl.col("a"), pl.col("b"))) + assert_gpu_result_equal(q) + + def test_boolean_is_in_raises_unsupported(): ldf = pl.LazyFrame({"a": pl.Series([1, 2, 3], dtype=pl.Int64)}) q = ldf.select(pl.col("a").is_in(pl.lit(1, dtype=pl.Int32()))) diff --git a/python/cudf_polars/tests/expressions/test_datetime_basic.py b/python/cudf_polars/tests/expressions/test_datetime_basic.py index 218101bf87c..c6ea29ddd38 100644 --- a/python/cudf_polars/tests/expressions/test_datetime_basic.py +++ b/python/cudf_polars/tests/expressions/test_datetime_basic.py @@ -9,7 +9,11 @@ import polars as pl -from cudf_polars.testing.asserts import assert_gpu_result_equal +from cudf_polars.dsl.expr import TemporalFunction +from cudf_polars.testing.asserts import ( + assert_gpu_result_equal, + assert_ir_translation_raises, +) @pytest.mark.parametrize( @@ -37,26 +41,97 @@ def test_datetime_dataframe_scan(dtype): assert_gpu_result_equal(query) +datetime_extract_fields = [ + "year", + "month", + "day", + "weekday", + "hour", + "minute", + "second", + "millisecond", + "microsecond", + "nanosecond", +] + + +@pytest.fixture( + ids=datetime_extract_fields, + params=[methodcaller(f) for f in datetime_extract_fields], +) +def field(request): + return request.param + + +def test_datetime_extract(field): + ldf = pl.LazyFrame( + { + "datetimes": pl.datetime_range( + datetime.datetime(2020, 1, 1), + datetime.datetime(2021, 12, 30), + "3mo14h15s11ms33us999ns", + eager=True, + ) + } + ) + + q = ldf.select(field(pl.col("datetimes").dt)) + + assert_gpu_result_equal(q) + + +def test_datetime_extra_unsupported(monkeypatch): + ldf = pl.LazyFrame( + { + "datetimes": pl.datetime_range( + datetime.datetime(2020, 1, 1), + datetime.datetime(2021, 12, 30), + "3mo14h15s11ms33us999ns", + eager=True, + ) + } + ) + + def unsupported_name_setter(self, value): + pass + + def unsupported_name_getter(self): + return "unsupported" + + monkeypatch.setattr( + TemporalFunction, + "name", + property(unsupported_name_getter, unsupported_name_setter), + ) + + q = ldf.select(pl.col("datetimes").dt.nanosecond()) + + assert_ir_translation_raises(q, NotImplementedError) + + @pytest.mark.parametrize( "field", [ methodcaller("year"), - pytest.param( - methodcaller("day"), - marks=pytest.mark.xfail(reason="day extraction not implemented"), - ), + methodcaller("month"), + methodcaller("day"), + methodcaller("weekday"), ], ) -def test_datetime_extract(field): +def test_date_extract(field): + ldf = pl.LazyFrame( + { + "dates": [ + datetime.date(2024, 1, 1), + datetime.date(2024, 10, 11), + ] + } + ) + ldf = pl.LazyFrame( {"dates": [datetime.date(2024, 1, 1), datetime.date(2024, 10, 11)]} ) - q = ldf.select(field(pl.col("dates").dt)) - with pytest.raises(AssertionError): - # polars produces int32, libcudf produces int16 for the year extraction - # libcudf can lose data here. - # https://github.com/rapidsai/cudf/issues/16196 - assert_gpu_result_equal(q) + q = ldf.select(field(pl.col("dates").dt)) - assert_gpu_result_equal(q, check_dtypes=False) + assert_gpu_result_equal(q) diff --git a/python/cudf_polars/tests/expressions/test_gather.py b/python/cudf_polars/tests/expressions/test_gather.py index 6bffa3e252c..f7c5d1bf2cd 100644 --- a/python/cudf_polars/tests/expressions/test_gather.py +++ b/python/cudf_polars/tests/expressions/test_gather.py @@ -6,7 +6,6 @@ import polars as pl -from cudf_polars import execute_with_cudf from cudf_polars.testing.asserts import assert_gpu_result_equal @@ -47,4 +46,4 @@ def test_gather_out_of_bounds(negative): query = ldf.select(pl.col("a").gather(pl.col("b"))) with pytest.raises(pl.exceptions.ComputeError): - query.collect(post_opt_callback=execute_with_cudf) + query.collect(engine="gpu") diff --git a/python/cudf_polars/tests/expressions/test_numeric_unaryops.py b/python/cudf_polars/tests/expressions/test_numeric_unaryops.py new file mode 100644 index 00000000000..ac3aecf88e6 --- /dev/null +++ b/python/cudf_polars/tests/expressions/test_numeric_unaryops.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import numpy as np +import pytest + +import polars as pl + +from cudf_polars.testing.asserts import assert_gpu_result_equal + + +@pytest.fixture( + params=[ + "sin", + "cos", + "tan", + "arcsin", + "arccos", + "arctan", + "sinh", + "cosh", + "tanh", + "arcsinh", + "arccosh", + "arctanh", + "exp", + "sqrt", + "cbrt", + "ceil", + "floor", + "abs", + ] +) +def op(request): + return request.param + + +@pytest.fixture(params=[pl.Int32, pl.Float32]) +def dtype(request): + return request.param + + +@pytest.fixture +def ldf(with_nulls, dtype): + values = [1, 2, 4, 5, -2, -4, 0] + if with_nulls: + values.append(None) + if dtype == pl.Float32: + values.append(-float("inf")) + values.append(float("nan")) + values.append(float("inf")) + elif dtype == pl.Int32: + iinfo = np.iinfo("int32") + values.append(iinfo.min) + values.append(iinfo.max) + return pl.LazyFrame( + { + "a": pl.Series(values, dtype=dtype), + "b": pl.Series([i - 4 for i in range(len(values))], dtype=pl.Float32), + } + ) + + +def test_unary(ldf, op): + expr = getattr(pl.col("a"), op)() + q = ldf.select(expr) + assert_gpu_result_equal(q, check_exact=False) + + +@pytest.mark.parametrize("base_literal", [False, True]) +@pytest.mark.parametrize("exponent_literal", [False, True]) +def test_pow(ldf, base_literal, exponent_literal): + base = pl.lit(2) if base_literal else pl.col("a") + exponent = pl.lit(-3, dtype=pl.Float32) if exponent_literal else pl.col("b") + + q = ldf.select(base.pow(exponent)) + + assert_gpu_result_equal(q, check_exact=False) + + +@pytest.mark.parametrize("natural", [True, False]) +def test_log(ldf, natural): + if natural: + expr = pl.col("a").log() + else: + expr = pl.col("a").log(10) + + q = ldf.select(expr) + + assert_gpu_result_equal(q, check_exact=False) diff --git a/python/cudf_polars/tests/expressions/test_stringfunction.py b/python/cudf_polars/tests/expressions/test_stringfunction.py index df08e15baa4..4f6850ac977 100644 --- a/python/cudf_polars/tests/expressions/test_stringfunction.py +++ b/python/cudf_polars/tests/expressions/test_stringfunction.py @@ -10,6 +10,7 @@ from cudf_polars import execute_with_cudf from cudf_polars.testing.asserts import ( + assert_collect_raises, assert_gpu_result_equal, assert_ir_translation_raises, ) @@ -152,3 +153,187 @@ def test_slice_column(slice_column_data): else: query = slice_column_data.select(pl.col("a").str.slice(pl.col("start"))) assert_ir_translation_raises(query, NotImplementedError) + + +@pytest.fixture +def to_datetime_data(): + return pl.LazyFrame( + { + "a": [ + "2021-01-01", + "2021-01-02", + "abcd", + ] + } + ) + + +@pytest.mark.parametrize("cache", [True, False], ids=lambda cache: f"{cache=}") +@pytest.mark.parametrize("strict", [True, False], ids=lambda strict: f"{strict=}") +@pytest.mark.parametrize("exact", [True, False], ids=lambda exact: f"{exact=}") +@pytest.mark.parametrize("format", ["%Y-%m-%d", None], ids=lambda format: f"{format=}") +def test_to_datetime(to_datetime_data, cache, strict, format, exact): + query = to_datetime_data.select( + pl.col("a").str.strptime( + pl.Datetime("ns"), format=format, cache=cache, strict=strict, exact=exact + ) + ) + if cache or format is None or not exact: + assert_ir_translation_raises(query, NotImplementedError) + elif strict: + assert_collect_raises( + query, + polars_except=pl.exceptions.InvalidOperationError, + cudf_except=pl.exceptions.ComputeError, + ) + else: + assert_gpu_result_equal(query) + + +@pytest.mark.parametrize( + "target, repl", + [("a", "a"), ("Wı", "☺"), ("FG", ""), ("doesnotexist", "blahblah")], # noqa: RUF001 +) +@pytest.mark.parametrize("n", [0, 3, -1]) +def test_replace_literal(ldf, target, repl, n): + query = ldf.select(pl.col("a").str.replace(target, repl, literal=True, n=n)) + assert_gpu_result_equal(query) + + +@pytest.mark.parametrize("target, repl", [("", ""), ("a", pl.col("a"))]) +def test_replace_literal_unsupported(ldf, target, repl): + query = ldf.select(pl.col("a").str.replace(target, repl, literal=True)) + assert_ir_translation_raises(query, NotImplementedError) + + +def test_replace_re(ldf): + query = ldf.select(pl.col("a").str.replace("A", "a", literal=False)) + assert_ir_translation_raises(query, NotImplementedError) + + +@pytest.mark.parametrize( + "target,repl", + [ + (["A", "de", "kLm", "awef"], "a"), + (["A", "de", "kLm", "awef"], ""), + (["A", "de", "kLm", "awef"], ["a", "b", "c", "d"]), + (["A", "de", "kLm", "awef"], ["a", "b", "c", ""]), + ( + pl.lit(pl.Series(["A", "de", "kLm", "awef"])), + pl.lit(pl.Series(["a", "b", "c", "d"])), + ), + ], +) +def test_replace_many(ldf, target, repl): + query = ldf.select(pl.col("a").str.replace_many(target, repl)) + + assert_gpu_result_equal(query) + + +@pytest.mark.parametrize( + "target,repl", + [(["A", ""], ["a", "b"]), (pl.col("a").drop_nulls(), pl.col("a").drop_nulls())], +) +def test_replace_many_notimplemented(ldf, target, repl): + query = ldf.select(pl.col("a").str.replace_many(target, repl)) + assert_ir_translation_raises(query, NotImplementedError) + + +def test_replace_many_ascii_case(ldf): + query = ldf.select( + pl.col("a").str.replace_many(["a", "b", "c"], "a", ascii_case_insensitive=True) + ) + + assert_ir_translation_raises(query, NotImplementedError) + + +_strip_data = [ + "AbC", + "123abc", + "", + " ", + None, + "aAaaaAAaa", + " ab c ", + "abc123", + " ", + "\tabc\t", + "\nabc\n", + "\r\nabc\r\n", + "\t\n abc \n\t", + "!@#$%^&*()", + " abc!!! ", + " abc\t\n!!! ", + "__abc__", + "abc\n\n", + "123abc456", + "abcxyzabc", +] + +strip_chars = [ + "a", + "", + " ", + "\t", + "\n", + "\r\n", + "!", + "@#", + "123", + "xyz", + "abc", + "__", + " \t\n", + "abc123", + None, +] + + +@pytest.fixture +def strip_ldf(): + return pl.DataFrame({"a": _strip_data}).lazy() + + +@pytest.fixture(params=strip_chars) +def to_strip(request): + return request.param + + +def test_strip_chars(strip_ldf, to_strip): + q = strip_ldf.select(pl.col("a").str.strip_chars(to_strip)) + assert_gpu_result_equal(q) + + +def test_strip_chars_start(strip_ldf, to_strip): + q = strip_ldf.select(pl.col("a").str.strip_chars_start(to_strip)) + assert_gpu_result_equal(q) + + +def test_strip_chars_end(strip_ldf, to_strip): + q = strip_ldf.select(pl.col("a").str.strip_chars_end(to_strip)) + assert_gpu_result_equal(q) + + +def test_strip_chars_column(strip_ldf): + q = strip_ldf.select(pl.col("a").str.strip_chars(pl.col("a"))) + assert_ir_translation_raises(q, NotImplementedError) + + +def test_invalid_regex_raises(): + df = pl.LazyFrame({"a": ["abc"]}) + + q = df.select(pl.col("a").str.contains(r"ab)", strict=True)) + + assert_collect_raises( + q, + polars_except=pl.exceptions.ComputeError, + cudf_except=pl.exceptions.ComputeError, + ) + + +@pytest.mark.parametrize("pattern", ["a{1000}", "a(?i:B)"]) +def test_unsupported_regex_raises(pattern): + df = pl.LazyFrame({"a": ["abc"]}) + + q = df.select(pl.col("a").str.contains(pattern, strict=True)) + assert_ir_translation_raises(q, NotImplementedError) diff --git a/python/cudf_polars/tests/test_config.py b/python/cudf_polars/tests/test_config.py index 5b4bba55552..3c3986be19b 100644 --- a/python/cudf_polars/tests/test_config.py +++ b/python/cudf_polars/tests/test_config.py @@ -6,6 +6,9 @@ import pytest import polars as pl +from polars.testing.asserts import assert_frame_equal + +import rmm from cudf_polars.dsl.ir import IR from cudf_polars.testing.asserts import ( @@ -32,3 +35,48 @@ def raise_unimplemented(self): ): # And ensure that collecting issues the correct warning. assert_gpu_result_equal(q) + + +def test_unsupported_config_raises(): + q = pl.LazyFrame({}) + + with pytest.raises(pl.exceptions.ComputeError): + q.collect(engine=pl.GPUEngine(unknown_key=True)) + + +@pytest.mark.parametrize("device", [-1, "foo"]) +def test_invalid_device_raises(device): + q = pl.LazyFrame({}) + with pytest.raises(pl.exceptions.ComputeError): + q.collect(engine=pl.GPUEngine(device=device)) + + +@pytest.mark.parametrize("mr", [1, object()]) +def test_invalid_memory_resource_raises(mr): + q = pl.LazyFrame({}) + with pytest.raises(pl.exceptions.ComputeError): + q.collect(engine=pl.GPUEngine(memory_resource=mr)) + + +def test_explicit_device_zero(): + q = pl.LazyFrame({"a": [1, 2, 3]}) + + result = q.collect(engine=pl.GPUEngine(device=0)) + assert_frame_equal(q.collect(), result) + + +def test_explicit_memory_resource(): + upstream = rmm.mr.CudaMemoryResource() + n_allocations = 0 + + def allocate(bytes, stream): + nonlocal n_allocations + n_allocations += 1 + return upstream.allocate(bytes, stream) + + mr = rmm.mr.CallbackMemoryResource(allocate, upstream.deallocate) + + q = pl.LazyFrame({"a": [1, 2, 3]}) + result = q.collect(engine=pl.GPUEngine(memory_resource=mr)) + assert_frame_equal(q.collect(), result) + assert n_allocations > 0 diff --git a/python/cudf_polars/tests/test_groupby.py b/python/cudf_polars/tests/test_groupby.py index a75825ef3d3..6f996e0e0ec 100644 --- a/python/cudf_polars/tests/test_groupby.py +++ b/python/cudf_polars/tests/test_groupby.py @@ -12,7 +12,6 @@ assert_gpu_result_equal, assert_ir_translation_raises, ) -from cudf_polars.utils import versions @pytest.fixture @@ -31,6 +30,7 @@ def df(): params=[ [pl.col("key1")], [pl.col("key2")], + [pl.col("key1"), pl.lit(1)], [pl.col("key1") * pl.col("key2")], [pl.col("key1"), pl.col("key2")], [pl.col("key1") == pl.col("key2")], @@ -52,6 +52,7 @@ def keys(request): [(pl.col("float") - pl.lit(2)).max()], [pl.col("float").sum().round(decimals=1)], [pl.col("float").round(decimals=1).sum()], + [pl.col("int").first(), pl.col("float").last()], ], ids=lambda aggs: "-".join(map(str, aggs)), ) @@ -60,15 +61,7 @@ def exprs(request): @pytest.fixture( - params=[ - False, - pytest.param( - True, - marks=pytest.mark.xfail( - reason="Maintaining order in groupby not implemented" - ), - ), - ], + params=[False, True], ids=["no_maintain_order", "maintain_order"], ) def maintain_order(request): @@ -98,15 +91,10 @@ def test_groupby_sorted_keys(df: pl.LazyFrame, keys, exprs): # Multiple keys don't do sorting qsorted = q.sort(*sort_keys) if len(keys) > 1: - with pytest.raises(AssertionError): - # https://github.com/pola-rs/polars/issues/17556 - assert_gpu_result_equal(q, check_exact=False) - if versions.POLARS_VERSION_LT_12 and schema[sort_keys[1]] == pl.Boolean(): - # https://github.com/pola-rs/polars/issues/17557 - with pytest.raises(AssertionError): - assert_gpu_result_equal(qsorted, check_exact=False) - else: - assert_gpu_result_equal(qsorted, check_exact=False) + # https://github.com/pola-rs/polars/issues/17556 + # Can't assert that the query without post-sorting fails, + # since it _might_ pass. + assert_gpu_result_equal(qsorted, check_exact=False) elif schema[sort_keys[0]] == pl.Boolean(): # Boolean keys don't do sorting, so we get random order assert_gpu_result_equal(qsorted, check_exact=False) @@ -133,6 +121,21 @@ def test_groupby_unsupported(df, expr): assert_ir_translation_raises(q, NotImplementedError) +def test_groupby_null_keys(maintain_order): + df = pl.LazyFrame( + { + "key": pl.Series([1, float("nan"), 2, None, 2, None], dtype=pl.Float64()), + "value": [-1, 2, 1, 2, 3, 4], + } + ) + + q = df.group_by("key", maintain_order=maintain_order).agg(pl.col("value").min()) + if not maintain_order: + q = q.sort("key") + + assert_gpu_result_equal(q) + + @pytest.mark.xfail(reason="https://github.com/pola-rs/polars/issues/17513") def test_groupby_minmax_with_nan(): df = pl.LazyFrame( @@ -159,15 +162,7 @@ def test_groupby_nan_minmax_raises(op): @pytest.mark.parametrize( "key", - [ - pytest.param( - 1, - marks=pytest.mark.xfail( - versions.POLARS_VERSION_GE_121, reason="polars 1.2.1 disallows this" - ), - ), - pl.col("key1"), - ], + [1, pl.col("key1")], ) @pytest.mark.parametrize( "expr", @@ -183,3 +178,12 @@ def test_groupby_literal_in_agg(df, key, expr): # so just sort by the group key q = df.group_by(key).agg(expr).sort(key, maintain_order=True) assert_gpu_result_equal(q) + + +@pytest.mark.parametrize( + "expr", + [pl.col("int").unique(), pl.col("int").drop_nulls(), pl.col("int").cum_max()], +) +def test_groupby_unary_non_pointwise_raises(df, expr): + q = df.group_by("key1").agg(expr) + assert_ir_translation_raises(q, NotImplementedError) diff --git a/python/cudf_polars/tests/test_groupby_dynamic.py b/python/cudf_polars/tests/test_groupby_dynamic.py new file mode 100644 index 00000000000..38b3ce74ac5 --- /dev/null +++ b/python/cudf_polars/tests/test_groupby_dynamic.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +from datetime import datetime + +import polars as pl + +from cudf_polars.testing.asserts import assert_ir_translation_raises + + +def test_groupby_dynamic_raises(): + df = pl.LazyFrame( + { + "dt": [ + datetime(2021, 12, 31, 0, 0, 0), + datetime(2022, 1, 1, 0, 0, 1), + datetime(2022, 3, 31, 0, 0, 1), + datetime(2022, 4, 1, 0, 0, 1), + ] + } + ) + + q = ( + df.sort("dt") + .group_by_dynamic("dt", every="1q") + .agg(pl.col("dt").count().alias("num_values")) + ) + assert_ir_translation_raises(q, NotImplementedError) diff --git a/python/cudf_polars/tests/test_join.py b/python/cudf_polars/tests/test_join.py index 1e880cdc6de..7d9ec98db97 100644 --- a/python/cudf_polars/tests/test_join.py +++ b/python/cudf_polars/tests/test_join.py @@ -17,7 +17,7 @@ def join_nulls(request): return request.param -@pytest.fixture(params=["inner", "left", "semi", "anti", "full"]) +@pytest.fixture(params=["inner", "left", "right", "semi", "anti", "full"]) def how(request): return request.param diff --git a/python/cudf_polars/tests/test_mapfunction.py b/python/cudf_polars/tests/test_mapfunction.py index 77032108e6f..e895f27f637 100644 --- a/python/cudf_polars/tests/test_mapfunction.py +++ b/python/cudf_polars/tests/test_mapfunction.py @@ -61,3 +61,48 @@ def test_rename_columns(mapping): q = df.rename(mapping) assert_gpu_result_equal(q) + + +@pytest.mark.parametrize("index", [None, ["a"], ["d", "a"]]) +@pytest.mark.parametrize("variable_name", [None, "names"]) +@pytest.mark.parametrize("value_name", [None, "unpivoted"]) +def test_unpivot(index, variable_name, value_name): + df = pl.LazyFrame( + { + "a": ["x", "y", "z"], + "b": pl.Series([1, 3, 5], dtype=pl.Int16), + "c": pl.Series([2, 4, 6], dtype=pl.Float32), + "d": ["a", "b", "c"], + } + ) + q = df.unpivot( + ["c", "b"], index=index, variable_name=variable_name, value_name=value_name + ) + + assert_gpu_result_equal(q) + + +def test_unpivot_defaults(): + df = pl.LazyFrame( + { + "a": pl.Series([11, 12, 13], dtype=pl.UInt16), + "b": pl.Series([1, 3, 5], dtype=pl.Int16), + "c": pl.Series([2, 4, 6], dtype=pl.Float32), + "d": ["a", "b", "c"], + } + ) + q = df.unpivot(index="d") + assert_gpu_result_equal(q) + + +def test_unpivot_unsupported_cast_raises(): + df = pl.LazyFrame( + { + "a": ["x", "y", "z"], + "b": pl.Series([1, 3, 5], dtype=pl.Int16), + } + ) + + q = df.unpivot(["a", "b"]) + + assert_ir_translation_raises(q, NotImplementedError) diff --git a/python/cudf_polars/tests/test_python_scan.py b/python/cudf_polars/tests/test_python_scan.py index fd8453b77c4..0cda89474a8 100644 --- a/python/cudf_polars/tests/test_python_scan.py +++ b/python/cudf_polars/tests/test_python_scan.py @@ -8,7 +8,9 @@ def test_python_scan(): - def source(with_columns, predicate, nrows): + def source(with_columns, predicate, nrows, *batch_size): + # PythonScan interface changes between 1.3 and 1.4 to add an + # extra batch_size argument return pl.DataFrame({"a": pl.Series([1, 2, 3], dtype=pl.Int8())}) q = pl.LazyFrame._scan_python_function({"a": pl.Int8}, source, pyarrow=False) diff --git a/python/cudf_polars/tests/test_scan.py b/python/cudf_polars/tests/test_scan.py index 64acbb076ed..792b136acd8 100644 --- a/python/cudf_polars/tests/test_scan.py +++ b/python/cudf_polars/tests/test_scan.py @@ -12,7 +12,6 @@ assert_gpu_result_equal, assert_ir_translation_raises, ) -from cudf_polars.utils import versions @pytest.fixture( @@ -58,6 +57,22 @@ def mask(request): return request.param +@pytest.fixture( + params=[ + None, + (1, 1), + ], + ids=[ + "no-slice", + "slice-second", + ], +) +def slice(request): + # For use in testing that we handle + # polars slice pushdown correctly + return request.param + + def make_source(df, path, format): """ Writes the passed polars df to a file of @@ -79,7 +94,9 @@ def make_source(df, path, format): ("parquet", pl.scan_parquet), ], ) -def test_scan(tmp_path, df, format, scan_fn, row_index, n_rows, columns, mask, request): +def test_scan( + tmp_path, df, format, scan_fn, row_index, n_rows, columns, mask, slice, request +): name, offset = row_index make_source(df, tmp_path / "file", format) request.applymarker( @@ -94,21 +111,23 @@ def test_scan(tmp_path, df, format, scan_fn, row_index, n_rows, columns, mask, r row_index_offset=offset, n_rows=n_rows, ) + if slice is not None: + q = q.slice(*slice) if mask is not None: q = q.filter(mask) if columns is not None: q = q.select(*columns) - polars_collect_kwargs = {} - if versions.POLARS_VERSION_LT_12: - # https://github.com/pola-rs/polars/issues/17553 - polars_collect_kwargs = {"projection_pushdown": False} - assert_gpu_result_equal( - q, - polars_collect_kwargs=polars_collect_kwargs, - # This doesn't work in polars < 1.2 since the row-index - # is in the wrong order in previous polars releases - check_column_order=versions.POLARS_VERSION_LT_12, - ) + assert_gpu_result_equal(q) + + +def test_negative_slice_pushdown_raises(tmp_path): + df = pl.DataFrame({"a": [1, 2, 3]}) + + df.write_parquet(tmp_path / "df.parquet") + q = pl.scan_parquet(tmp_path / "df.parquet") + # Take the last row + q = q.slice(-1, 1) + assert_ir_translation_raises(q, NotImplementedError) def test_scan_unsupported_raises(tmp_path): @@ -127,10 +146,6 @@ def test_scan_ndjson_nrows_notimplemented(tmp_path, df): assert_ir_translation_raises(q, NotImplementedError) -@pytest.mark.xfail( - versions.POLARS_VERSION_LT_11, - reason="https://github.com/pola-rs/polars/issues/15730", -) def test_scan_row_index_projected_out(tmp_path): df = pl.DataFrame({"a": [1, 2, 3]}) @@ -169,15 +184,25 @@ def test_scan_csv_column_renames_projection_schema(tmp_path): ("test*.csv", False), ], ) -def test_scan_csv_multi(tmp_path, filename, glob): +@pytest.mark.parametrize( + "nrows_skiprows", + [ + (None, 0), + (1, 1), + (3, 0), + (4, 2), + ], +) +def test_scan_csv_multi(tmp_path, filename, glob, nrows_skiprows): + n_rows, skiprows = nrows_skiprows with (tmp_path / "test1.csv").open("w") as f: - f.write("""foo,bar,baz\n1,2\n3,4,5""") + f.write("""foo,bar,baz\n1,2,3\n3,4,5""") with (tmp_path / "test2.csv").open("w") as f: - f.write("""foo,bar,baz\n1,2\n3,4,5""") + f.write("""foo,bar,baz\n1,2,3\n3,4,5""") with (tmp_path / "test*.csv").open("w") as f: - f.write("""foo,bar,baz\n1,2\n3,4,5""") + f.write("""foo,bar,baz\n1,2,3\n3,4,5""") os.chdir(tmp_path) - q = pl.scan_csv(filename, glob=glob) + q = pl.scan_csv(filename, glob=glob, n_rows=n_rows, skip_rows=skiprows) assert_gpu_result_equal(q) @@ -280,3 +305,24 @@ def test_scan_ndjson_unsupported(df, tmp_path): make_source(df, tmp_path / "file", "ndjson") q = pl.scan_ndjson(tmp_path / "file", ignore_errors=True) assert_ir_translation_raises(q, NotImplementedError) + + +def test_scan_parquet_nested_null_raises(tmp_path): + df = pl.DataFrame({"a": pl.Series([None], dtype=pl.List(pl.Null))}) + + df.write_parquet(tmp_path / "file.pq") + + q = pl.scan_parquet(tmp_path / "file.pq") + + assert_ir_translation_raises(q, NotImplementedError) + + +def test_scan_parquet_only_row_index_raises(df, tmp_path): + make_source(df, tmp_path / "file", "parquet") + q = pl.scan_parquet(tmp_path / "file", row_index_name="index").select("index") + assert_ir_translation_raises(q, NotImplementedError) + + +def test_scan_hf_url_raises(): + q = pl.scan_csv("hf://datasets/scikit-learn/iris/Iris.csv") + assert_ir_translation_raises(q, NotImplementedError) diff --git a/python/cudf_polars/tests/test_sort.py b/python/cudf_polars/tests/test_sort.py index ecc02efd967..cfa8e5ff9b9 100644 --- a/python/cudf_polars/tests/test_sort.py +++ b/python/cudf_polars/tests/test_sort.py @@ -13,10 +13,7 @@ "sort_keys", [ (pl.col("a"),), - pytest.param( - (pl.col("d").abs(),), - marks=pytest.mark.xfail(reason="abs not yet implemented"), - ), + (pl.col("d").abs(),), (pl.col("a"), pl.col("d")), (pl.col("b"),), ], diff --git a/python/cudf_polars/tests/testing/test_asserts.py b/python/cudf_polars/tests/testing/test_asserts.py index 5bc2fe1efb7..8e7f1a09d9b 100644 --- a/python/cudf_polars/tests/testing/test_asserts.py +++ b/python/cudf_polars/tests/testing/test_asserts.py @@ -7,7 +7,10 @@ import polars as pl +from cudf_polars.containers import DataFrame +from cudf_polars.dsl.ir import Select from cudf_polars.testing.asserts import ( + assert_collect_raises, assert_gpu_result_equal, assert_ir_translation_raises, ) @@ -26,10 +29,62 @@ def test_translation_assert_raises(): class E(Exception): pass - unsupported = df.group_by("a").agg(pl.col("a").cum_max().alias("b")) + unsupported = df.group_by("a").agg(pl.col("a").upper_bound().alias("b")) # Unsupported query should raise NotImplementedError assert_ir_translation_raises(unsupported, NotImplementedError) with pytest.raises(AssertionError): # This should fail, because we can't translate this query, but it doesn't raise E. assert_ir_translation_raises(unsupported, E) + + +def test_collect_assert_raises(monkeypatch): + df = pl.LazyFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) + + with pytest.raises(AssertionError): + # This should raise, because polars CPU can run this query + assert_collect_raises( + df, + polars_except=pl.exceptions.InvalidOperationError, + cudf_except=pl.exceptions.InvalidOperationError, + ) + + # Here's an invalid query that gets caught at IR optimisation time. + q = df.select(pl.col("a") * pl.col("b")) + + # This exception is raised in preprocessing, so is the same for + # both CPU and GPU engines. + assert_collect_raises( + q, + polars_except=pl.exceptions.InvalidOperationError, + cudf_except=pl.exceptions.InvalidOperationError, + ) + + with pytest.raises(AssertionError): + # This should raise because the expected GPU error is wrong + assert_collect_raises( + q, + polars_except=pl.exceptions.InvalidOperationError, + cudf_except=NotImplementedError, + ) + + with pytest.raises(AssertionError): + # This should raise because the expected CPU error is wrong + assert_collect_raises( + q, + polars_except=NotImplementedError, + cudf_except=pl.exceptions.InvalidOperationError, + ) + + with monkeypatch.context() as m: + m.setattr(Select, "evaluate", lambda self, cache: DataFrame([])) + # This query should fail, but we monkeypatch a bad + # implementation of Select which "succeeds" to check that our + # assertion notices this case. + q = df.select(pl.col("a") + pl.Series([1, 2])) + with pytest.raises(AssertionError): + assert_collect_raises( + q, + polars_except=pl.exceptions.ComputeError, + cudf_except=pl.exceptions.ComputeError, + ) diff --git a/python/pylibcudf/pylibcudf/datetime.pyx b/python/pylibcudf/pylibcudf/datetime.pyx index 0ddc68bcb9d..e8e0caaf42d 100644 --- a/python/pylibcudf/pylibcudf/datetime.pyx +++ b/python/pylibcudf/pylibcudf/datetime.pyx @@ -2,7 +2,19 @@ from libcpp.memory cimport unique_ptr from libcpp.utility cimport move from pylibcudf.libcudf.column.column cimport column -from pylibcudf.libcudf.datetime cimport extract_year as cpp_extract_year +from pylibcudf.libcudf.datetime cimport ( + day_of_year as cpp_day_of_year, + extract_day as cpp_extract_day, + extract_hour as cpp_extract_hour, + extract_microsecond_fraction as cpp_extract_microsecond_fraction, + extract_millisecond_fraction as cpp_extract_millisecond_fraction, + extract_minute as cpp_extract_minute, + extract_month as cpp_extract_month, + extract_nanosecond_fraction as cpp_extract_nanosecond_fraction, + extract_second as cpp_extract_second, + extract_weekday as cpp_extract_weekday, + extract_year as cpp_extract_year, +) from .column cimport Column @@ -28,3 +40,42 @@ cpdef Column extract_year( with nogil: result = move(cpp_extract_year(values.view())) return Column.from_libcudf(move(result)) + + +def extract_datetime_component(Column col, str field): + + cdef unique_ptr[column] c_result + + with nogil: + if field == "year": + c_result = move(cpp_extract_year(col.view())) + elif field == "month": + c_result = move(cpp_extract_month(col.view())) + elif field == "day": + c_result = move(cpp_extract_day(col.view())) + elif field == "weekday": + c_result = move(cpp_extract_weekday(col.view())) + elif field == "hour": + c_result = move(cpp_extract_hour(col.view())) + elif field == "minute": + c_result = move(cpp_extract_minute(col.view())) + elif field == "second": + c_result = move(cpp_extract_second(col.view())) + elif field == "millisecond": + c_result = move( + cpp_extract_millisecond_fraction(col.view()) + ) + elif field == "microsecond": + c_result = move( + cpp_extract_microsecond_fraction(col.view()) + ) + elif field == "nanosecond": + c_result = move( + cpp_extract_nanosecond_fraction(col.view()) + ) + elif field == "day_of_year": + c_result = move(cpp_day_of_year(col.view())) + else: + raise ValueError(f"Invalid datetime field: '{field}'") + + return Column.from_libcudf(move(c_result)) diff --git a/python/pylibcudf/pylibcudf/libcudf/strings/CMakeLists.txt b/python/pylibcudf/pylibcudf/libcudf/strings/CMakeLists.txt index bd6e2e0af02..abf4357f862 100644 --- a/python/pylibcudf/pylibcudf/libcudf/strings/CMakeLists.txt +++ b/python/pylibcudf/pylibcudf/libcudf/strings/CMakeLists.txt @@ -12,7 +12,7 @@ # the License. # ============================================================================= -set(cython_sources char_types.pyx regex_flags.pyx) +set(cython_sources char_types.pyx regex_flags.pyx side_type.pyx) set(linked_libraries cudf::cudf) diff --git a/python/pylibcudf/pylibcudf/libcudf/strings/side_type.pxd b/python/pylibcudf/pylibcudf/libcudf/strings/side_type.pxd index 3a89299f11a..019ff3f17ba 100644 --- a/python/pylibcudf/pylibcudf/libcudf/strings/side_type.pxd +++ b/python/pylibcudf/pylibcudf/libcudf/strings/side_type.pxd @@ -1,10 +1,10 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2024, NVIDIA CORPORATION. from libc.stdint cimport int32_t cdef extern from "cudf/strings/side_type.hpp" namespace "cudf::strings" nogil: - ctypedef enum side_type: + cpdef enum class side_type(int32_t): LEFT 'cudf::strings::side_type::LEFT' RIGHT 'cudf::strings::side_type::RIGHT' BOTH 'cudf::strings::side_type::BOTH' diff --git a/python/pylibcudf/pylibcudf/libcudf/strings/side_type.pyx b/python/pylibcudf/pylibcudf/libcudf/strings/side_type.pyx new file mode 100644 index 00000000000..e69de29bb2d diff --git a/python/pylibcudf/pylibcudf/strings/CMakeLists.txt b/python/pylibcudf/pylibcudf/strings/CMakeLists.txt index d3065cf8667..8b4fbb1932f 100644 --- a/python/pylibcudf/pylibcudf/strings/CMakeLists.txt +++ b/python/pylibcudf/pylibcudf/strings/CMakeLists.txt @@ -12,8 +12,9 @@ # the License. # ============================================================================= -set(cython_sources capitalize.pyx case.pyx char_types.pyx contains.pyx extract.pyx find.pyx - regex_flags.pyx regex_program.pyx repeat.pyx replace.pyx slice.pyx +set(cython_sources + capitalize.pyx case.pyx char_types.pyx contains.pyx extract.pyx find.pyx regex_flags.pyx + regex_program.pyx repeat.pyx replace.pyx side_type.pyx slice.pyx strip.pyx ) set(linked_libraries cudf::cudf) @@ -22,3 +23,5 @@ rapids_cython_create_modules( SOURCE_FILES "${cython_sources}" LINKED_LIBRARIES "${linked_libraries}" MODULE_PREFIX pylibcudf_strings_ ASSOCIATED_TARGETS cudf ) + +add_subdirectory(convert) diff --git a/python/pylibcudf/pylibcudf/strings/__init__.pxd b/python/pylibcudf/pylibcudf/strings/__init__.pxd index 6848c8e6e86..4867d944dc7 100644 --- a/python/pylibcudf/pylibcudf/strings/__init__.pxd +++ b/python/pylibcudf/pylibcudf/strings/__init__.pxd @@ -5,10 +5,13 @@ from . cimport ( case, char_types, contains, + convert, extract, find, regex_flags, regex_program, replace, slice, + strip, ) +from .side_type cimport side_type diff --git a/python/pylibcudf/pylibcudf/strings/__init__.py b/python/pylibcudf/pylibcudf/strings/__init__.py index bba86e818cc..a3bef64d19f 100644 --- a/python/pylibcudf/pylibcudf/strings/__init__.py +++ b/python/pylibcudf/pylibcudf/strings/__init__.py @@ -5,6 +5,7 @@ case, char_types, contains, + convert, extract, find, regex_flags, @@ -12,4 +13,6 @@ repeat, replace, slice, + strip, ) +from .side_type import SideType diff --git a/python/pylibcudf/pylibcudf/strings/convert/CMakeLists.txt b/python/pylibcudf/pylibcudf/strings/convert/CMakeLists.txt new file mode 100644 index 00000000000..175c9b3738e --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/convert/CMakeLists.txt @@ -0,0 +1,22 @@ +# ============================================================================= +# Copyright (c) 2024, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +set(cython_sources convert_durations.pyx convert_datetime.pyx) + +set(linked_libraries cudf::cudf) +rapids_cython_create_modules( + CXX + SOURCE_FILES "${cython_sources}" + LINKED_LIBRARIES "${linked_libraries}" MODULE_PREFIX pylibcudf_strings_ ASSOCIATED_TARGETS cudf +) diff --git a/python/pylibcudf/pylibcudf/strings/convert/__init__.pxd b/python/pylibcudf/pylibcudf/strings/convert/__init__.pxd new file mode 100644 index 00000000000..05324cb49df --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/convert/__init__.pxd @@ -0,0 +1,2 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +from . cimport convert_datetime, convert_durations diff --git a/python/pylibcudf/pylibcudf/strings/convert/__init__.py b/python/pylibcudf/pylibcudf/strings/convert/__init__.py new file mode 100644 index 00000000000..d803399d53c --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/convert/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. +from . import convert_datetime, convert_durations diff --git a/python/pylibcudf/pylibcudf/strings/convert/convert_datetime.pxd b/python/pylibcudf/pylibcudf/strings/convert/convert_datetime.pxd new file mode 100644 index 00000000000..07c84d263d6 --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/convert/convert_datetime.pxd @@ -0,0 +1,18 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from libcpp.string cimport string +from pylibcudf.column cimport Column +from pylibcudf.types cimport DataType + + +cpdef Column to_timestamps( + Column input, + DataType timestamp_type, + const string& format +) + +cpdef Column from_timestamps( + Column input, + const string& format, + Column input_strings_names +) diff --git a/python/pylibcudf/pylibcudf/strings/convert/convert_datetime.pyx b/python/pylibcudf/pylibcudf/strings/convert/convert_datetime.pyx new file mode 100644 index 00000000000..fcacb096f87 --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/convert/convert_datetime.pyx @@ -0,0 +1,56 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from libcpp.memory cimport unique_ptr +from libcpp.string cimport string +from libcpp.utility cimport move +from pylibcudf.column cimport Column +from pylibcudf.libcudf.column.column cimport column +from pylibcudf.libcudf.strings.convert cimport ( + convert_datetime as cpp_convert_datetime, +) + +from pylibcudf.types import DataType + + +cpdef Column to_timestamps( + Column input, + DataType timestamp_type, + const string& format +): + cdef unique_ptr[column] c_result + with nogil: + c_result = cpp_convert_datetime.to_timestamps( + input.view(), + timestamp_type.c_obj, + format + ) + + return Column.from_libcudf(move(c_result)) + +cpdef Column from_timestamps( + Column input, + const string& format, + Column input_strings_names +): + cdef unique_ptr[column] c_result + with nogil: + c_result = cpp_convert_datetime.from_timestamps( + input.view(), + format, + input_strings_names.view() + ) + + return Column.from_libcudf(move(c_result)) + +cpdef Column is_timestamp( + Column input, + const string& format +): + cdef unique_ptr[column] c_result + with nogil: + c_result = cpp_convert_datetime.is_timestamp( + input.view(), + format + ) + + return Column.from_libcudf(move(c_result)) diff --git a/python/pylibcudf/pylibcudf/strings/convert/convert_durations.pxd b/python/pylibcudf/pylibcudf/strings/convert/convert_durations.pxd new file mode 100644 index 00000000000..ac11b8959ed --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/convert/convert_durations.pxd @@ -0,0 +1,17 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from libcpp.string cimport string +from pylibcudf.column cimport Column +from pylibcudf.types cimport DataType + + +cpdef Column to_durations( + Column input, + DataType duration_type, + const string& format +) + +cpdef Column from_durations( + Column input, + const string& format +) diff --git a/python/pylibcudf/pylibcudf/strings/convert/convert_durations.pyx b/python/pylibcudf/pylibcudf/strings/convert/convert_durations.pyx new file mode 100644 index 00000000000..f3e0b7c9c8e --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/convert/convert_durations.pyx @@ -0,0 +1,41 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from libcpp.memory cimport unique_ptr +from libcpp.string cimport string +from libcpp.utility cimport move +from pylibcudf.column cimport Column +from pylibcudf.libcudf.column.column cimport column +from pylibcudf.libcudf.strings.convert cimport ( + convert_durations as cpp_convert_durations, +) + +from pylibcudf.types import DataType + + +cpdef Column to_durations( + Column input, + DataType duration_type, + const string& format +): + cdef unique_ptr[column] c_result + with nogil: + c_result = cpp_convert_durations.to_durations( + input.view(), + duration_type.c_obj, + format + ) + + return Column.from_libcudf(move(c_result)) + +cpdef Column from_durations( + Column input, + const string& format +): + cdef unique_ptr[column] c_result + with nogil: + c_result = cpp_convert_durations.from_durations( + input.view(), + format + ) + + return Column.from_libcudf(move(c_result)) diff --git a/python/pylibcudf/pylibcudf/strings/side_type.pxd b/python/pylibcudf/pylibcudf/strings/side_type.pxd new file mode 100644 index 00000000000..34b7a580380 --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/side_type.pxd @@ -0,0 +1,3 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from pylibcudf.libcudf.strings.side_type cimport side_type diff --git a/python/pylibcudf/pylibcudf/strings/side_type.pyx b/python/pylibcudf/pylibcudf/strings/side_type.pyx new file mode 100644 index 00000000000..acdc7d6ff1f --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/side_type.pyx @@ -0,0 +1,4 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from pylibcudf.libcudf.strings.side_type import \ + side_type as SideType # no-cython-lint diff --git a/python/pylibcudf/pylibcudf/strings/strip.pxd b/python/pylibcudf/pylibcudf/strings/strip.pxd new file mode 100644 index 00000000000..8bbe4753edd --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/strip.pxd @@ -0,0 +1,12 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from pylibcudf.column cimport Column +from pylibcudf.scalar cimport Scalar +from pylibcudf.strings.side_type cimport side_type + + +cpdef Column strip( + Column input, + side_type side=*, + Scalar to_strip=* +) diff --git a/python/pylibcudf/pylibcudf/strings/strip.pyx b/python/pylibcudf/pylibcudf/strings/strip.pyx new file mode 100644 index 00000000000..429a23c3cdf --- /dev/null +++ b/python/pylibcudf/pylibcudf/strings/strip.pyx @@ -0,0 +1,60 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from cython.operator cimport dereference +from libcpp.memory cimport unique_ptr +from libcpp.utility cimport move +from pylibcudf.column cimport Column +from pylibcudf.libcudf.column.column cimport column +from pylibcudf.libcudf.scalar.scalar cimport string_scalar +from pylibcudf.libcudf.scalar.scalar_factories cimport ( + make_string_scalar as cpp_make_string_scalar, +) +from pylibcudf.libcudf.strings cimport strip as cpp_strip +from pylibcudf.scalar cimport Scalar +from pylibcudf.strings.side_type cimport side_type + + +cpdef Column strip( + Column input, + side_type side=side_type.BOTH, + Scalar to_strip=None +): + """Removes the specified characters from the beginning + or end (or both) of each string. + + For details, see :cpp:func:`cudf::strings::strip`. + + Parameters + ---------- + input : Column + Strings column for this operation + side : SideType, default SideType.BOTH + Indicates characters are to be stripped from the beginning, + end, or both of each string; Default is both + to_strip : Scalar + UTF-8 encoded characters to strip from each string; + Default is empty string which indicates strip whitespace characters + + Returns + ------- + pylibcudf.Column + New strings column. + """ + + if to_strip is None: + to_strip = Scalar.from_libcudf( + cpp_make_string_scalar("".encode()) + ) + + cdef unique_ptr[column] c_result + cdef string_scalar* cpp_to_strip + cpp_to_strip = (to_strip.c_obj.get()) + + with nogil: + c_result = cpp_strip.strip( + input.view(), + side, + dereference(cpp_to_strip) + ) + + return Column.from_libcudf(move(c_result)) diff --git a/python/pylibcudf/pylibcudf/tests/test_datetime.py b/python/pylibcudf/pylibcudf/tests/test_datetime.py index d3aa6101e2d..89c96829e71 100644 --- a/python/pylibcudf/pylibcudf/tests/test_datetime.py +++ b/python/pylibcudf/pylibcudf/tests/test_datetime.py @@ -1,6 +1,7 @@ # Copyright (c) 2024, NVIDIA CORPORATION. import datetime +import functools import pyarrow as pa import pyarrow.compute as pc @@ -10,7 +11,7 @@ @pytest.fixture -def column(has_nulls): +def date_column(has_nulls): values = [ datetime.date(1999, 1, 1), datetime.date(2024, 10, 12), @@ -22,9 +23,41 @@ def column(has_nulls): return plc.interop.from_arrow(pa.array(values, type=pa.date32())) -def test_extract_year(column): - got = plc.datetime.extract_year(column) +@pytest.fixture(scope="module", params=["s", "ms", "us", "ns"]) +def datetime_column(has_nulls, request): + values = [ + datetime.datetime(1999, 1, 1), + datetime.datetime(2024, 10, 12), + datetime.datetime(1970, 1, 1), + datetime.datetime(2260, 1, 1), + datetime.datetime(2024, 2, 29, 3, 14, 15), + datetime.datetime(2024, 2, 29, 3, 14, 15, 999), + ] + if has_nulls: + values[2] = None + return plc.interop.from_arrow( + pa.array(values, type=pa.timestamp(request.param)) + ) + + +@pytest.mark.parametrize( + "component, pc_fun", + [ + ("year", pc.year), + ("month", pc.month), + ("day", pc.day), + ("weekday", functools.partial(pc.day_of_week, count_from_zero=False)), + ("hour", pc.hour), + ("minute", pc.minute), + ("second", pc.second), + ("millisecond", pc.millisecond), + ("microsecond", pc.microsecond), + ("nanosecond", pc.nanosecond), + ], +) +def test_extraction(datetime_column, component, pc_fun): + got = plc.datetime.extract_datetime_component(datetime_column, component) # libcudf produces an int16, arrow produces an int64 - expect = pc.year(plc.interop.to_arrow(column)).cast(pa.int16()) + expect = pc_fun(plc.interop.to_arrow(datetime_column)).cast(pa.int16()) assert_column_eq(expect, got) diff --git a/python/pylibcudf/pylibcudf/tests/test_string_convert.py b/python/pylibcudf/pylibcudf/tests/test_string_convert.py new file mode 100644 index 00000000000..e9e95459d0e --- /dev/null +++ b/python/pylibcudf/pylibcudf/tests/test_string_convert.py @@ -0,0 +1,85 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +from datetime import datetime + +import pyarrow as pa +import pylibcudf as plc +import pytest +from utils import assert_column_eq + + +@pytest.fixture( + scope="module", + params=[ + pa.timestamp("ns"), + pa.timestamp("us"), + pa.timestamp("ms"), + pa.timestamp("s"), + ], +) +def timestamp_type(request): + return request.param + + +@pytest.fixture( + scope="module", + params=[ + pa.duration("ns"), + pa.duration("us"), + pa.duration("ms"), + pa.duration("s"), + ], +) +def duration_type(request): + return request.param + + +@pytest.fixture(scope="module") +def pa_timestamp_col(): + return pa.array(["2011-01-01", "2011-01-02", "2011-01-03"]) + + +@pytest.fixture(scope="module") +def pa_duration_col(): + return pa.array(["05:20:25"]) + + +@pytest.fixture(scope="module") +def plc_timestamp_col(pa_timestamp_col): + return plc.interop.from_arrow(pa_timestamp_col) + + +@pytest.fixture(scope="module") +def plc_duration_col(pa_duration_col): + return plc.interop.from_arrow(pa_duration_col) + + +@pytest.mark.parametrize("format", ["%Y-%m-%d"]) +def test_to_datetime( + pa_timestamp_col, plc_timestamp_col, timestamp_type, format +): + expect = pa.compute.strptime(pa_timestamp_col, format, timestamp_type.unit) + got = plc.strings.convert.convert_datetime.to_timestamps( + plc_timestamp_col, + plc.interop.from_arrow(timestamp_type), + format.encode(), + ) + assert_column_eq(expect, got) + + +@pytest.mark.parametrize("format", ["%H:%M:%S"]) +def test_to_duration(pa_duration_col, plc_duration_col, duration_type, format): + def to_timedelta(duration_str): + date = datetime.strptime(duration_str, format) + return date - datetime(1900, 1, 1) # "%H:%M:%S" zero date + + expect = pa.array([to_timedelta(d.as_py()) for d in pa_duration_col]).cast( + duration_type + ) + + got = plc.strings.convert.convert_durations.to_durations( + plc_duration_col, + plc.interop.from_arrow(duration_type), + format.encode(), + ) + assert_column_eq(expect, got) diff --git a/python/pylibcudf/pylibcudf/tests/test_string_strip.py b/python/pylibcudf/pylibcudf/tests/test_string_strip.py new file mode 100644 index 00000000000..005e5e4a405 --- /dev/null +++ b/python/pylibcudf/pylibcudf/tests/test_string_strip.py @@ -0,0 +1,122 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. + +import pyarrow as pa +import pylibcudf as plc +import pytest +from utils import assert_column_eq + +data_strings = [ + "AbC", + "123abc", + "", + " ", + None, + "aAaaaAAaa", + " ab c ", + "abc123", + " ", + "\tabc\t", + "\nabc\n", + "\r\nabc\r\n", + "\t\n abc \n\t", + "!@#$%^&*()", + " abc!!! ", + " abc\t\n!!! ", + "__abc__", + "abc\n\n", + "123abc456", + "abcxyzabc", +] + +strip_chars = [ + "a", + "", + " ", + "\t", + "\n", + "\r\n", + "!", + "@#", + "123", + "xyz", + "abc", + "__", + " \t\n", + "abc123", +] + + +@pytest.fixture +def pa_col(): + return pa.array(data_strings, type=pa.string()) + + +@pytest.fixture +def plc_col(pa_col): + return plc.interop.from_arrow(pa_col) + + +@pytest.fixture(params=strip_chars) +def pa_char(request): + return pa.scalar(request.param, type=pa.string()) + + +@pytest.fixture +def plc_char(pa_char): + return plc.interop.from_arrow(pa_char) + + +def test_strip(pa_col, plc_col, pa_char, plc_char): + def strip_string(st, char): + if st is None: + return None + + elif char == "": + return st.strip() + return st.strip(char) + + expected = pa.array( + [strip_string(x, pa_char.as_py()) for x in pa_col.to_pylist()], + type=pa.string(), + ) + + got = plc.strings.strip.strip(plc_col, plc.strings.SideType.BOTH, plc_char) + assert_column_eq(expected, got) + + +def test_strip_right(pa_col, plc_col, pa_char, plc_char): + def strip_string(st, char): + if st is None: + return None + + elif char == "": + return st.rstrip() + return st.rstrip(char) + + expected = pa.array( + [strip_string(x, pa_char.as_py()) for x in pa_col.to_pylist()], + type=pa.string(), + ) + + got = plc.strings.strip.strip( + plc_col, plc.strings.SideType.RIGHT, plc_char + ) + assert_column_eq(expected, got) + + +def test_strip_left(pa_col, plc_col, pa_char, plc_char): + def strip_string(st, char): + if st is None: + return None + + elif char == "": + return st.lstrip() + return st.lstrip(char) + + expected = pa.array( + [strip_string(x, pa_char.as_py()) for x in pa_col.to_pylist()], + type=pa.string(), + ) + + got = plc.strings.strip.strip(plc_col, plc.strings.SideType.LEFT, plc_char) + assert_column_eq(expected, got)