diff --git a/cpp/cmake/thirdparty/get_cucollections.cmake b/cpp/cmake/thirdparty/get_cucollections.cmake
index 9758958b44f..6ec35ddcaf1 100644
--- a/cpp/cmake/thirdparty/get_cucollections.cmake
+++ b/cpp/cmake/thirdparty/get_cucollections.cmake
@@ -1,5 +1,5 @@
 # =============================================================================
-# Copyright (c) 2021-2022, NVIDIA CORPORATION.
+# Copyright (c) 2021-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 # in compliance with the License. You may obtain a copy of the License at
@@ -15,6 +15,11 @@
 # This function finds cuCollections and performs any additional configuration.
 function(find_and_configure_cucollections)
   include(${rapids-cmake-dir}/cpm/cuco.cmake)
+  include(${rapids-cmake-dir}/cpm/package_override.cmake)
+
+  set(cudf_patch_dir "${CMAKE_CURRENT_FUNCTION_LIST_DIR}/patches")
+  rapids_cpm_package_override("${cudf_patch_dir}/cuco_override.json")
+
   if(BUILD_SHARED_LIBS)
     rapids_cpm_cuco(BUILD_EXPORT_SET cudf-exports)
   else()
diff --git a/cpp/cmake/thirdparty/patches/cuco_noexcept.diff b/cpp/cmake/thirdparty/patches/cuco_noexcept.diff
new file mode 100644
index 00000000000..0f334c0e81f
--- /dev/null
+++ b/cpp/cmake/thirdparty/patches/cuco_noexcept.diff
@@ -0,0 +1,227 @@
+diff --git a/include/cuco/aow_storage.cuh b/include/cuco/aow_storage.cuh
+index 7f9de01..5228193 100644
+--- a/include/cuco/aow_storage.cuh
++++ b/include/cuco/aow_storage.cuh
+@@ -81,7 +81,7 @@ class aow_storage : public detail::aow_storage_base {
+    * @param size Number of windows to (de)allocate
+    * @param allocator Allocator used for (de)allocating device storage
+    */
+-  explicit constexpr aow_storage(Extent size, Allocator const& allocator = {}) noexcept;
++  explicit constexpr aow_storage(Extent size, Allocator const& allocator = {});
+
+   aow_storage(aow_storage&&) = default;  ///< Move constructor
+   /**
+@@ -122,7 +122,7 @@ class aow_storage : public detail::aow_storage_base {
+    * @param key Key to which all keys in `slots` are initialized
+    * @param stream Stream used for executing the kernel
+    */
+-  void initialize(value_type key, cuda_stream_ref stream = {}) noexcept;
++  void initialize(value_type key, cuda_stream_ref stream = {});
+
+   /**
+    * @brief Asynchronously initializes each slot in the AoW storage to contain `key`.
+diff --git a/include/cuco/detail/open_addressing/open_addressing_impl.cuh b/include/cuco/detail/open_addressing/open_addressing_impl.cuh
+index c2c9c14..8ac4236 100644
+--- a/include/cuco/detail/open_addressing/open_addressing_impl.cuh
++++ b/include/cuco/detail/open_addressing/open_addressing_impl.cuh
+@@ -125,7 +125,7 @@ class open_addressing_impl {
+                        KeyEqual const& pred,
+                        ProbingScheme const& probing_scheme,
+                        Allocator const& alloc,
+-                       cuda_stream_ref stream) noexcept
++                       cuda_stream_ref stream)
+     : empty_slot_sentinel_{empty_slot_sentinel},
+       erased_key_sentinel_{this->extract_key(empty_slot_sentinel)},
+       predicate_{pred},
+@@ -233,7 +233,7 @@ class open_addressing_impl {
+    *
+    * @param stream CUDA stream this operation is executed in
+    */
+-  void clear(cuda_stream_ref stream) noexcept { storage_.initialize(empty_slot_sentinel_, stream); }
++  void clear(cuda_stream_ref stream) { storage_.initialize(empty_slot_sentinel_, stream); }
+
+   /**
+    * @brief Asynchronously erases all elements from the container. After this call, `size()` returns
+@@ -599,7 +599,7 @@ class open_addressing_impl {
+    *
+    * @return The number of elements in the container
+    */
+-  [[nodiscard]] size_type size(cuda_stream_ref stream) const noexcept
++  [[nodiscard]] size_type size(cuda_stream_ref stream) const
+   {
+     auto counter =
+       detail::counter_storage{this->allocator()};
+diff --git a/include/cuco/detail/static_map/static_map.inl b/include/cuco/detail/static_map/static_map.inl
+index e17a145..3fa1d02 100644
+--- a/include/cuco/detail/static_map/static_map.inl
++++ b/include/cuco/detail/static_map/static_map.inl
+@@ -123,7 +123,7 @@ template
+ void static_map::clear(
+-  cuda_stream_ref stream) noexcept
++  cuda_stream_ref stream)
+ {
+   impl_->clear(stream);
+ }
+@@ -215,7 +215,7 @@ template
+ template
+ void static_map::
+-  insert_or_assign(InputIt first, InputIt last, cuda_stream_ref stream) noexcept
++  insert_or_assign(InputIt first, InputIt last, cuda_stream_ref stream)
+ {
+   return this->insert_or_assign_async(first, last, stream);
+   stream.synchronize();
+@@ -465,7 +465,7 @@ template
+ static_map::size_type
+ static_map::size(
+-  cuda_stream_ref stream) const noexcept
++  cuda_stream_ref stream) const
+ {
+   return impl_->size(stream);
+ }
+diff --git a/include/cuco/detail/static_multiset/static_multiset.inl b/include/cuco/detail/static_multiset/static_multiset.inl
+index 174f9bc..582926b 100644
+--- a/include/cuco/detail/static_multiset/static_multiset.inl
++++ b/include/cuco/detail/static_multiset/static_multiset.inl
+@@ -97,7 +97,7 @@ template
+ void static_multiset::clear(
+-  cuda_stream_ref stream) noexcept
++  cuda_stream_ref stream)
+ {
+   impl_->clear(stream);
+ }
+@@ -183,7 +183,7 @@ template
+ static_multiset::size_type
+ static_multiset::size(
+-  cuda_stream_ref stream) const noexcept
++  cuda_stream_ref stream) const
+ {
+   return impl_->size(stream);
+ }
+diff --git a/include/cuco/detail/static_set/static_set.inl b/include/cuco/detail/static_set/static_set.inl
+index 645013f..d3cece0 100644
+--- a/include/cuco/detail/static_set/static_set.inl
++++ b/include/cuco/detail/static_set/static_set.inl
+@@ -98,7 +98,7 @@ template
+ void static_set::clear(
+-  cuda_stream_ref stream) noexcept
++  cuda_stream_ref stream)
+ {
+   impl_->clear(stream);
+ }
+@@ -429,7 +429,7 @@ template
+ static_set::size_type
+ static_set::size(
+-  cuda_stream_ref stream) const noexcept
++  cuda_stream_ref stream) const
+ {
+   return impl_->size(stream);
+ }
+diff --git a/include/cuco/detail/storage/aow_storage.inl b/include/cuco/detail/storage/aow_storage.inl
+index 3547f4c..94b7f98 100644
+--- a/include/cuco/detail/storage/aow_storage.inl
++++ b/include/cuco/detail/storage/aow_storage.inl
+@@ -32,8 +32,8 @@
+ namespace cuco {
+
+ template
+-constexpr aow_storage::aow_storage(
+-  Extent size, Allocator const& allocator) noexcept
++constexpr aow_storage::aow_storage(Extent size,
++                                   Allocator const& allocator)
+   : detail::aow_storage_base{size},
+     allocator_{allocator},
+     window_deleter_{capacity(), allocator_},
+@@ -64,7 +64,7 @@ aow_storage::ref() const noexcept
+
+ template
+ void aow_storage::initialize(value_type key,
+-                             cuda_stream_ref stream) noexcept
++                             cuda_stream_ref stream)
+ {
+   this->initialize_async(key, stream);
+   stream.synchronize();
+diff --git a/include/cuco/static_map.cuh b/include/cuco/static_map.cuh
+index c86e90c..95da423 100644
+--- a/include/cuco/static_map.cuh
++++ b/include/cuco/static_map.cuh
+@@ -269,7 +269,7 @@ class static_map {
+    *
+    * @param stream CUDA stream this operation is executed in
+    */
+-  void clear(cuda_stream_ref stream = {}) noexcept;
++  void clear(cuda_stream_ref stream = {});
+
+   /**
+    * @brief Asynchronously erases all elements from the container. After this call, `size()` returns
+@@ -387,7 +387,7 @@ class static_map {
+    * @param stream CUDA stream used for insert
+    */
+   template
+-  void insert_or_assign(InputIt first, InputIt last, cuda_stream_ref stream = {}) noexcept;
++  void insert_or_assign(InputIt first, InputIt last, cuda_stream_ref stream = {});
+
+   /**
+    * @brief For any key-value pair `{k, v}` in the range `[first, last)`, if a key equivalent to `k`
+@@ -690,7 +690,7 @@ class static_map {
+    * @param stream CUDA stream used to get the number of inserted elements
+    * @return The number of elements in the container
+    */
+-  [[nodiscard]] size_type size(cuda_stream_ref stream = {}) const noexcept;
++  [[nodiscard]] size_type size(cuda_stream_ref stream = {}) const;
+
+   /**
+    * @brief Gets the maximum number of elements the hash map can hold.
+diff --git a/include/cuco/static_multiset.cuh b/include/cuco/static_multiset.cuh
+index 0daf103..fbcbc9c 100644
+--- a/include/cuco/static_multiset.cuh
++++ b/include/cuco/static_multiset.cuh
+@@ -235,7 +235,7 @@ class static_multiset {
+    *
+    * @param stream CUDA stream this operation is executed in
+    */
+-  void clear(cuda_stream_ref stream = {}) noexcept;
++  void clear(cuda_stream_ref stream = {});
+
+   /**
+    * @brief Asynchronously erases all elements from the container. After this call, `size()` returns
+@@ -339,7 +339,7 @@ class static_multiset {
+    * @param stream CUDA stream used to get the number of inserted elements
+    * @return The number of elements in the container
+    */
+-  [[nodiscard]] size_type size(cuda_stream_ref stream = {}) const noexcept;
++  [[nodiscard]] size_type size(cuda_stream_ref stream = {}) const;
+
+   /**
+    * @brief Gets the maximum number of elements the multiset can hold.
+diff --git a/include/cuco/static_set.cuh b/include/cuco/static_set.cuh
+index a069939..3517f84 100644
+--- a/include/cuco/static_set.cuh
++++ b/include/cuco/static_set.cuh
+@@ -240,7 +240,7 @@ class static_set {
+    *
+    * @param stream CUDA stream this operation is executed in
+    */
+-  void clear(cuda_stream_ref stream = {}) noexcept;
++  void clear(cuda_stream_ref stream = {});
+
+   /**
+    * @brief Asynchronously erases all elements from the container. After this call, `size()` returns
+@@ -687,7 +687,7 @@ class static_set {
+    * @param stream CUDA stream used to get the number of inserted elements
+    * @return The number of elements in the container
+    */
+-  [[nodiscard]] size_type size(cuda_stream_ref stream = {}) const noexcept;
++  [[nodiscard]] size_type size(cuda_stream_ref stream = {}) const;
+
+   /**
+    * @brief Gets the maximum number of elements the hash set can hold.
diff --git a/cpp/cmake/thirdparty/patches/cuco_override.json b/cpp/cmake/thirdparty/patches/cuco_override.json
new file mode 100644
index 00000000000..ae0a9a4b4f0
--- /dev/null
+++ b/cpp/cmake/thirdparty/patches/cuco_override.json
@@ -0,0 +1,14 @@
+
+{
+  "packages" : {
+    "cuco" : {
+      "patches" : [
+        {
+          "file" : "${current_json_dir}/cuco_noexcept.diff",
+          "issue" : "Remove erroneous noexcept clauses on cuco functions that may throw [https://github.com/rapidsai/cudf/issues/16059]",
+          "fixed_in" : ""
+        }
+      ]
+    }
+  }
+}
diff --git a/python/cudf_polars/cudf_polars/dsl/expr.py b/python/cudf_polars/cudf_polars/dsl/expr.py
index 17d7d15e4e5..16cfd9b9749 100644
--- a/python/cudf_polars/cudf_polars/dsl/expr.py
+++ b/python/cudf_polars/cudf_polars/dsl/expr.py
@@ -27,11 +27,12 @@
 import cudf._lib.pylibcudf as plc
 
 from cudf_polars.containers import Column, NamedColumn
-from cudf_polars.utils import sorting
+from cudf_polars.utils import dtypes, sorting
 
 if TYPE_CHECKING:
     from collections.abc import Mapping, Sequence
 
+    import polars.polars as plrs
     import polars.type_aliases as pl_types
 
     from cudf_polars.containers import DataFrame
@@ -369,6 +370,29 @@ def do_evaluate(
         return Column(plc.Column.from_scalar(plc.interop.from_arrow(self.value), 1))
 
 
+class LiteralColumn(Expr):
+    __slots__ = ("value",)
+    _non_child = ("dtype", "value")
+    value: pa.Array[Any, Any]
+    children: tuple[()]
+
+    def __init__(self, dtype: plc.DataType, value: plrs.PySeries) -> None:
+        super().__init__(dtype)
+        data = value.to_arrow()
+        self.value = data.cast(dtypes.downcast_arrow_lists(data.type))
+
+    def do_evaluate(
+        self,
+        df: DataFrame,
+        *,
+        context: ExecutionContext = ExecutionContext.FRAME,
+        mapping: Mapping[Expr, Column] | None = None,
+    ) -> Column:
+        """Evaluate this expression given a dataframe for context."""
+        # datatype of pyarrow array is correct by construction.
+        return Column(plc.interop.from_arrow(self.value))
+
+
 class Col(Expr):
     __slots__ = ("name",)
     _non_child = ("dtype", "name")
@@ -1156,6 +1180,12 @@ def __init__(
         super().__init__(dtype)
         self.op = op
         self.children = (left, right)
+        if (
+            op in (plc.binaryop.BinaryOperator.ADD, plc.binaryop.BinaryOperator.SUB)
+            and ({left.dtype.id(), right.dtype.id()}.issubset(dtypes.TIMELIKE_TYPES))
+            and not dtypes.have_compatible_resolution(left.dtype.id(), right.dtype.id())
+        ):
+            raise NotImplementedError("Casting rules for timelike types")
 
     _MAPPING: ClassVar[dict[pl_expr.Operator, plc.binaryop.BinaryOperator]] = {
         pl_expr.Operator.Eq: plc.binaryop.BinaryOperator.EQUAL,
diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py
index 3f5f3c74050..abe26b14a90 100644
--- a/python/cudf_polars/cudf_polars/dsl/ir.py
+++ b/python/cudf_polars/cudf_polars/dsl/ir.py
@@ -29,7 +29,7 @@
 import cudf_polars.dsl.expr as expr
 from cudf_polars.containers import DataFrame, NamedColumn
-from cudf_polars.utils import sorting
+from cudf_polars.utils import dtypes, sorting
 
 if TYPE_CHECKING:
     from collections.abc import MutableMapping
@@ -130,6 +130,11 @@ class IR:
     schema: Schema
     """Mapping from column names to their data types."""
 
+    def __post_init__(self):
+        """Validate preconditions."""
+        if any(dtype.id() == plc.TypeId.EMPTY for dtype in self.schema.values()):
+            raise NotImplementedError("Cannot make empty columns.")
+
     def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
         """
         Evaluate the node and return a dataframe.
@@ -292,15 +297,10 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
         table = pdf.to_arrow()
         schema = table.schema
         for i, field in enumerate(schema):
-            # TODO: Nested types
-            if field.type == pa.large_string():
-                # TODO: goes away when libcudf supports large strings
-                schema = schema.set(i, pa.field(field.name, pa.string()))
-            elif isinstance(field.type, pa.LargeListType):
-                # TODO: goes away when libcudf supports large lists
-                schema = schema.set(
-                    i, pa.field(field.name, pa.list_(field.type.field(0)))
-                )
+            schema = schema.set(
+                i, pa.field(field.name, dtypes.downcast_arrow_lists(field.type))
+            )
+        # No-op if the schema is unchanged.
         table = table.cast(schema)
         df = DataFrame.from_table(
             plc.interop.from_arrow(table), list(self.schema.keys())
diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py
index 953ff636cce..f4bf07ae1e0 100644
--- a/python/cudf_polars/cudf_polars/dsl/translate.py
+++ b/python/cudf_polars/cudf_polars/dsl/translate.py
@@ -12,6 +12,7 @@
 import pyarrow as pa
 from typing_extensions import assert_never
 
+import polars.polars as plrs
 from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir
 
 import cudf._lib.pylibcudf as plc
@@ -383,6 +384,8 @@ def _(node: pl_expr.Window, visitor: NodeTraverser, dtype: plc.DataType) -> expr
 
 @_translate_expr.register
 def _(node: pl_expr.Literal, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
+    if isinstance(node.value, plrs.PySeries):
+        return expr.LiteralColumn(dtype, node.value)
     value = pa.scalar(node.value, type=plc.interop.to_arrow(dtype))
     return expr.Literal(dtype, value)
 
diff --git a/python/cudf_polars/cudf_polars/utils/dtypes.py b/python/cudf_polars/cudf_polars/utils/dtypes.py
index 3d4a643e1fc..507acb5d33a 100644
--- a/python/cudf_polars/cudf_polars/utils/dtypes.py
+++ b/python/cudf_polars/cudf_polars/utils/dtypes.py
@@ -7,13 +7,92 @@
 
 from functools import cache
 
+import pyarrow as pa
 from typing_extensions import assert_never
 
 import polars as pl
 
 import cudf._lib.pylibcudf as plc
 
-__all__ = ["from_polars"]
+__all__ = ["from_polars", "downcast_arrow_lists", "have_compatible_resolution"]
+
+
+TIMELIKE_TYPES: frozenset[plc.TypeId] = frozenset(
+    [
+        plc.TypeId.TIMESTAMP_MILLISECONDS,
+        plc.TypeId.TIMESTAMP_MICROSECONDS,
+        plc.TypeId.TIMESTAMP_NANOSECONDS,
+        plc.TypeId.TIMESTAMP_DAYS,
+        plc.TypeId.DURATION_MILLISECONDS,
+        plc.TypeId.DURATION_MICROSECONDS,
+        plc.TypeId.DURATION_NANOSECONDS,
+    ]
+)
+
+
+def have_compatible_resolution(lid: plc.TypeId, rid: plc.TypeId):
+    """
+    Do two datetime typeids have matching resolution for a binop.
+
+    Parameters
+    ----------
+    lid
+        Left type id
+    rid
+        Right type id
+
+    Returns
+    -------
+    True if resolutions are compatible, False otherwise.
+
+    Notes
+    -----
+    Polars has different casting rules for combining
+    datetimes/durations than libcudf, and while we don't encode the
+    casting rules fully, just reject things we can't handle.
+
+    Precondition for correctness: both lid and rid are timelike.
+    """
+    if lid == rid:
+        return True
+    # Timestamps are smaller than durations in the libcudf enum.
+    lid, rid = sorted([lid, rid])
+    if lid == plc.TypeId.TIMESTAMP_MILLISECONDS:
+        return rid == plc.TypeId.DURATION_MILLISECONDS
+    elif lid == plc.TypeId.TIMESTAMP_MICROSECONDS:
+        return rid == plc.TypeId.DURATION_MICROSECONDS
+    elif lid == plc.TypeId.TIMESTAMP_NANOSECONDS:
+        return rid == plc.TypeId.DURATION_NANOSECONDS
+    return False
+
+
+def downcast_arrow_lists(typ: pa.DataType) -> pa.DataType:
+    """
+    Sanitize an arrow datatype from polars.
+
+    Parameters
+    ----------
+    typ
+        Arrow type to sanitize
+
+    Returns
+    -------
+    Sanitized arrow type
+
+    Notes
+    -----
+    As well as arrow ``ListType``s, polars can produce
+    ``LargeListType``s and ``FixedSizeListType``s, these are not
+    currently handled by libcudf, so we attempt to cast them all into
+    normal ``ListType``s on the arrow side before consuming the arrow
+    data.
+    """
+    if isinstance(typ, pa.LargeListType):
+        return pa.list_(downcast_arrow_lists(typ.value_type))
+    # We don't have to worry about diving into struct types for now
+    # since those are always NotImplemented before we get here.
+    assert not isinstance(typ, pa.StructType)
+    return typ
 
 
 @cache
diff --git a/python/cudf_polars/tests/expressions/test_literal.py b/python/cudf_polars/tests/expressions/test_literal.py
new file mode 100644
index 00000000000..55e688428bd
--- /dev/null
+++ b/python/cudf_polars/tests/expressions/test_literal.py
@@ -0,0 +1,96 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+import pytest
+
+import polars as pl
+
+from cudf_polars.testing.asserts import (
+    assert_gpu_result_equal,
+    assert_ir_translation_raises,
+)
+from cudf_polars.utils import dtypes
+
+
+@pytest.fixture(
+    params=[
+        None,
+        pl.Int8(),
+        pl.Int16(),
+        pl.Int32(),
+        pl.Int64(),
+        pl.UInt8(),
+        pl.UInt16(),
+        pl.UInt32(),
+        pl.UInt64(),
+    ]
+)
+def integer(request):
+    return pl.lit(10, dtype=request.param)
+
+
+@pytest.fixture(params=[None, pl.Float32(), pl.Float64()])
+def float(request):
+    return pl.lit(1.0, dtype=request.param)
+
+
+def test_numeric_literal(integer, float):
+    df = pl.LazyFrame({})
+
+    q = df.select(integer=integer, float_=float, sum_=integer + float)
+
+    assert_gpu_result_equal(q)
+
+
+@pytest.fixture(
+    params=[pl.Date(), pl.Datetime("ms"), pl.Datetime("us"), pl.Datetime("ns")]
+)
+def timestamp(request):
+    return pl.lit(10_000, dtype=request.param)
+
+
+@pytest.fixture(params=[pl.Duration("ms"), pl.Duration("us"), pl.Duration("ns")])
+def timedelta(request):
+    return pl.lit(9_000, dtype=request.param)
+
+
+def test_timelike_literal(timestamp, timedelta):
+    df = pl.LazyFrame({})
+
+    q = df.select(
+        time=timestamp,
+        delta=timedelta,
+        adjusted=timestamp + timedelta,
+        two_delta=timedelta + timedelta,
+    )
+    schema = q.collect_schema()
+    time_type = schema["time"]
+    delta_type = schema["delta"]
+    if dtypes.have_compatible_resolution(
+        dtypes.from_polars(time_type).id(), dtypes.from_polars(delta_type).id()
+    ):
+        assert_gpu_result_equal(q)
+    else:
+        assert_ir_translation_raises(q, NotImplementedError)
+
+
+def test_select_literal_series():
+    df = pl.LazyFrame({})
+
+    q = df.select(
+        a=pl.Series(["a", "b", "c"], dtype=pl.String()),
+        b=pl.Series([[1, 2], [3], None], dtype=pl.List(pl.UInt16())),
+        c=pl.Series([[[1]], [], [[1, 2, 3, 4]]], dtype=pl.List(pl.List(pl.Float32()))),
+    )
+
+    assert_gpu_result_equal(q)
+
+
+@pytest.mark.parametrize("expr", [pl.lit(None), pl.lit(10, dtype=pl.Decimal())])
+def test_unsupported_literal_raises(expr):
+    df = pl.LazyFrame({})
+
+    q = df.select(expr)
+
+    assert_ir_translation_raises(q, NotImplementedError)
diff --git a/python/cudf_polars/tests/test_dataframescan.py b/python/cudf_polars/tests/test_dataframescan.py
index 1ffe06ac562..b5c0fb7be9f 100644
--- a/python/cudf_polars/tests/test_dataframescan.py
+++ b/python/cudf_polars/tests/test_dataframescan.py
@@ -41,3 +41,22 @@ def test_scan_drop_nulls(subset, predicate_pushdown):
     assert_gpu_result_equal(
         q, collect_kwargs={"predicate_pushdown": predicate_pushdown}
     )
+
+
+def test_can_convert_lists():
+    df = pl.LazyFrame(
+        {
+            "a": pl.Series([[1, 2], [3]], dtype=pl.List(pl.Int8())),
+            "b": pl.Series([[1], [2]], dtype=pl.List(pl.UInt16())),
+            "c": pl.Series(
+                [
+                    [["1", "2", "3"], ["4", "567"]],
+                    [["8", "9"], []],
+                ],
+                dtype=pl.List(pl.List(pl.String())),
+            ),
+            "d": pl.Series([[[1, 2]], []], dtype=pl.List(pl.List(pl.UInt16()))),
+        }
+    )
+
+    assert_gpu_result_equal(df)
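
Reviewer note (not part of the patch): both `LiteralColumn.__init__` and `DataFrameScan.evaluate` funnel list-typed arrow data through the new `dtypes.downcast_arrow_lists` helper before handing it to libcudf. The sketch below re-implements that helper with plain pyarrow so the cast behaviour can be checked without a GPU or a cudf build; it mirrors, rather than imports, the code added in this PR, and it assumes a reasonably recent pyarrow that supports casting (nested) large_list arrays to list arrays.

```python
# Illustrative sketch only: a standalone re-implementation of the PR's
# downcast_arrow_lists logic, exercised with hand-built pyarrow arrays.
import pyarrow as pa


def downcast_arrow_lists(typ: pa.DataType) -> pa.DataType:
    """Recursively rewrite large_list types as 32-bit-offset list types."""
    if isinstance(typ, pa.LargeListType):
        return pa.list_(downcast_arrow_lists(typ.value_type))
    # Struct types are assumed to have been rejected earlier (as in the PR).
    assert not isinstance(typ, pa.StructType)
    return typ


# Polars exports List columns as arrow large_list; emulate that here.
large = pa.array([[1, 2], [3], None], type=pa.large_list(pa.uint16()))
narrow = large.cast(downcast_arrow_lists(large.type))
assert narrow.type == pa.list_(pa.uint16())

# Nested lists are handled by the recursion as well.
nested = pa.array([[[1.0]], []], type=pa.large_list(pa.large_list(pa.float32())))
assert nested.cast(downcast_arrow_lists(nested.type)).type == pa.list_(
    pa.list_(pa.float32())
)
```

This is the same cast that `LiteralColumn.__init__` applies to `value.to_arrow()` before calling `plc.interop.from_arrow`, and that `DataFrameScan.evaluate` applies field by field to the scanned schema.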