Skip to content

Commit

Permalink
cudf-polars string/numeric casting (#17076)
Browse files Browse the repository at this point in the history
Depends on #16991
Part of #17060

Implements cross casting from string <-> numeric types in `cudf-polars`

Authors:
  - https://github.com/brandon-b-miller
  - Matthew Murray (https://github.com/Matt711)
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)
  - Muhammad Haseeb (https://github.com/mhaseeb123)
  - Matthew Murray (https://github.com/Matt711)

URL: #17076
  • Loading branch information
brandon-b-miller authored Nov 7, 2024
1 parent 4cbc15a commit e4c52dd
Show file tree
Hide file tree
Showing 13 changed files with 311 additions and 24 deletions.
24 changes: 24 additions & 0 deletions cpp/include/cudf/utilities/traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,30 @@ constexpr inline bool is_integral_not_bool()
*/
bool is_integral_not_bool(data_type type);

/**
 * @brief Compile-time check that `T` is a numeric type other than `bool`.
 *
 * `bool` is rejected explicitly; every other type is classified by
 * `cudf::is_numeric<T>()`.
 *
 * @tparam T The type to verify
 * @return true `T` is numeric but not bool
 * @return false `T` is not numeric or is bool
 */
template <typename T>
constexpr inline bool is_numeric_not_bool()
{
  if constexpr (std::is_same_v<T, bool>) {
    // bool is excluded regardless of how is_numeric classifies it.
    return false;
  } else {
    return cudf::is_numeric<T>();
  }
}

/**
 * @brief Indicates whether `type` is a numeric `data_type` but not BOOL8
 *
 * "Numeric" types are integral/floating point types such as `INT*` or `FLOAT*`.
 *
 * This is the runtime counterpart of the compile-time trait
 * `is_numeric_not_bool<T>()`.
 *
 * @param type The `data_type` to verify
 * @return true `type` is numeric but not bool
 * @return false `type` is not numeric or is bool
 */
bool is_numeric_not_bool(data_type type);

/**
* @brief Indicates whether the type `T` is a floating point type.
*
Expand Down
13 changes: 13 additions & 0 deletions cpp/src/utilities/traits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,19 @@ bool is_integral_not_bool(data_type type)
return cudf::type_dispatcher(type, is_integral_not_bool_impl{});
}

// Functor handed to `cudf::type_dispatcher`: for the dispatched type `T` it
// forwards to the compile-time trait `is_numeric_not_bool<T>()`.
struct is_numeric_not_bool_impl {
  template <typename T>
  constexpr bool operator()()
  {
    return is_numeric_not_bool<T>();
  }
};

// Runtime check: true when `type` is numeric and is not the bool type.
bool is_numeric_not_bool(data_type type)
{
  // Resolve `type` to its C++ type and evaluate the compile-time trait on it.
  is_numeric_not_bool_impl trait_check{};
  return cudf::type_dispatcher(type, trait_check);
}

struct is_floating_point_impl {
template <typename T>
constexpr bool operator()()
Expand Down
56 changes: 51 additions & 5 deletions python/cudf_polars/cudf_polars/containers/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,18 @@
import functools
from typing import TYPE_CHECKING

from polars.exceptions import InvalidOperationError

import pylibcudf as plc
from pylibcudf.strings.convert.convert_floats import from_floats, is_float, to_floats
from pylibcudf.strings.convert.convert_integers import (
from_integers,
is_integer,
to_integers,
)
from pylibcudf.traits import is_floating_point

from cudf_polars.utils.dtypes import is_order_preserving_cast

if TYPE_CHECKING:
from typing_extensions import Self
Expand Down Expand Up @@ -129,11 +140,46 @@ def astype(self, dtype: plc.DataType) -> Column:
This only produces a copy if the requested dtype doesn't match
the current one.
"""
if self.obj.type() != dtype:
return Column(plc.unary.cast(self.obj, dtype), name=self.name).sorted_like(
self
)
return self
if self.obj.type() == dtype:
return self

if dtype.id() == plc.TypeId.STRING or self.obj.type().id() == plc.TypeId.STRING:
return Column(self._handle_string_cast(dtype))
else:
result = Column(plc.unary.cast(self.obj, dtype))
if is_order_preserving_cast(self.obj.type(), dtype):
return result.sorted_like(self)
return result

def _handle_string_cast(self, dtype: plc.DataType) -> plc.Column:
    """
    Convert between a string column and a numeric column.

    Parameters
    ----------
    dtype
        Target datatype. When it is STRING the column is formatted as
        strings; otherwise the column is parsed into ``dtype``.
        NOTE(review): in the parse direction the column is presumed to
        hold strings; this method does not check that itself.

    Returns
    -------
    The converted pylibcudf column.

    Raises
    ------
    InvalidOperationError
        When parsing and the ``all`` reduction over the per-row
        validity check is not truthy.
    """
    if dtype.id() == plc.TypeId.STRING:
        # Numeric -> string: the formatter depends on the source type.
        formatter = from_floats if is_floating_point(self.obj.type()) else from_integers
        return formatter(self.obj)

    # String -> numeric: the validity check and parser depend on the target type.
    if is_floating_point(dtype):
        is_parseable, parse = is_float, to_floats
    else:
        is_parseable, parse = is_integer, to_integers

    all_parseable = plc.reduce.reduce(
        is_parseable(self.obj),
        plc.aggregation.all(),
        plc.DataType(plc.TypeId.BOOL8),
    )
    if not plc.interop.to_arrow(all_parseable).as_py():
        raise InvalidOperationError("Conversion from `str` failed.")
    return parse(self.obj, dtype)

def copy_metadata(self, from_: pl.Series, /) -> Self:
"""
Expand Down
4 changes: 2 additions & 2 deletions python/cudf_polars/cudf_polars/dsl/expressions/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, dtype: plc.DataType, value: Expr) -> None:
self.children = (value,)
if not dtypes.can_cast(value.dtype, self.dtype):
raise NotImplementedError(
f"Can't cast {self.dtype.id().name} to {value.dtype.id().name}"
f"Can't cast {value.dtype.id().name} to {self.dtype.id().name}"
)

def do_evaluate(
Expand All @@ -48,7 +48,7 @@ def do_evaluate(
"""Evaluate this expression given a dataframe for context."""
(child,) = self.children
column = child.evaluate(df, context=context, mapping=mapping)
return Column(plc.unary.cast(column.obj, self.dtype)).sorted_like(column)
return column.astype(self.dtype)

def collect_agg(self, *, depth: int) -> AggInfo:
"""Collect information about aggregations in groupbys."""
Expand Down
1 change: 1 addition & 0 deletions python/cudf_polars/cudf_polars/testing/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def pytest_configure(config: pytest.Config) -> None:
"tests/unit/sql/test_cast.py::test_cast_errors[values0-values::uint8-conversion from `f64` to `u64` failed]": "Casting that raises not supported on GPU",
"tests/unit/sql/test_cast.py::test_cast_errors[values1-values::uint4-conversion from `i64` to `u32` failed]": "Casting that raises not supported on GPU",
"tests/unit/sql/test_cast.py::test_cast_errors[values2-values::int1-conversion from `i64` to `i8` failed]": "Casting that raises not supported on GPU",
"tests/unit/sql/test_cast.py::test_cast_errors[values5-values::int4-conversion from `str` to `i32` failed]": "Cast raises, but error user receives is wrong",
"tests/unit/sql/test_miscellaneous.py::test_read_csv": "Incorrect handling of missing_is_null in read_csv",
"tests/unit/sql/test_wildcard_opts.py::test_select_wildcard_errors": "Raises correctly but with different exception",
"tests/unit/streaming/test_streaming_io.py::test_parquet_eq_statistics": "Debug output on stderr doesn't match",
Expand Down
65 changes: 60 additions & 5 deletions python/cudf_polars/cudf_polars/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,20 @@

import polars as pl

from pylibcudf.traits import (
is_floating_point,
is_integral_not_bool,
is_numeric_not_bool,
)

__all__ = [
"from_polars",
"downcast_arrow_lists",
"can_cast",
"is_order_preserving_cast",
]
import pylibcudf as plc

__all__ = ["from_polars", "downcast_arrow_lists", "can_cast"]


def downcast_arrow_lists(typ: pa.DataType) -> pa.DataType:
"""
Expand Down Expand Up @@ -62,9 +72,54 @@ def can_cast(from_: plc.DataType, to: plc.DataType) -> bool:
True if casting is supported, False otherwise
"""
return (
plc.traits.is_fixed_width(to)
and plc.traits.is_fixed_width(from_)
and plc.unary.is_supported_cast(from_, to)
(
plc.traits.is_fixed_width(to)
and plc.traits.is_fixed_width(from_)
and plc.unary.is_supported_cast(from_, to)
)
or (from_.id() == plc.TypeId.STRING and is_numeric_not_bool(to))
or (to.id() == plc.TypeId.STRING and is_numeric_not_bool(from_))
)


def is_order_preserving_cast(from_: plc.DataType, to: plc.DataType) -> bool:
    """
    Determine if a cast would preserve the order of the source data.

    Parameters
    ----------
    from_
        Source datatype
    to
        Target datatype

    Returns
    -------
    True if the cast is order-preserving, False otherwise
    """
    if from_.id() == to.id():
        return True

    src_integral = is_integral_not_bool(from_)
    dst_integral = is_integral_not_bool(to)
    if src_integral and dst_integral:
        src_unsigned = plc.traits.is_unsigned(from_)
        src_size = plc.types.size_of(from_)
        dst_size = plc.types.size_of(to)
        if src_unsigned == plc.traits.is_unsigned(to):
            # Same signedness: the target must be at least as wide.
            return dst_size >= src_size
        # Signedness differs. Unsigned -> signed needs a strictly wider
        # target; signed -> unsigned is never accepted because of
        # negative values.
        return src_unsigned and dst_size > src_size

    src_floating = is_floating_point(from_)
    dst_floating = is_floating_point(to)
    if src_floating and dst_floating:
        # Float -> float: the target must be at least as wide.
        return plc.types.size_of(to) >= plc.types.size_of(from_)

    # Casts between an integer type and a floating point type are accepted
    # in either direction; every remaining combination is rejected.
    return (src_integral and dst_floating) or (src_floating and dst_integral)


Expand Down
2 changes: 1 addition & 1 deletion python/cudf_polars/tests/expressions/test_casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
_supported_dtypes = [(pl.Int8(), pl.Int64())]

_unsupported_dtypes = [
(pl.String(), pl.Int64()),
(pl.Datetime("ns"), pl.Int64()),
]


Expand Down
10 changes: 0 additions & 10 deletions python/cudf_polars/tests/expressions/test_numeric_binops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from cudf_polars.testing.asserts import (
assert_gpu_result_equal,
assert_ir_translation_raises,
)

dtypes = [
Expand Down Expand Up @@ -114,12 +113,3 @@ def test_binop_with_scalar(left_scalar, right_scalar):
q = df.select(lop / rop)

assert_gpu_result_equal(q)


def test_numeric_to_string_cast_fails():
df = pl.DataFrame(
{"a": [1, 1, 2, 3, 3, 4, 1], "b": [None, 2, 3, 4, 5, 6, 7]}
).lazy()
q = df.select(pl.col("a").cast(pl.String))

assert_ir_translation_raises(q, NotImplementedError)
117 changes: 117 additions & 0 deletions python/cudf_polars/tests/expressions/test_stringfunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,79 @@ def ldf(with_nulls):
)


@pytest.fixture(params=[pl.Int8, pl.Int16, pl.Int32, pl.Int64])
def integer_type(request):
    """Yield each polars integer dtype listed in ``params``, one per test run."""
    dtype = request.param
    return dtype


@pytest.fixture(params=[pl.Float32, pl.Float64])
def floating_type(request):
    """Yield each polars floating point dtype listed in ``params``."""
    dtype = request.param
    return dtype


@pytest.fixture(params=[pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.Float32, pl.Float64])
def numeric_type(request):
    """Yield each polars integer and floating point dtype listed in ``params``."""
    dtype = request.param
    return dtype


@pytest.fixture
def str_to_integer_data(with_nulls):
    """LazyFrame with a string column ``a`` of integer literals "1" to "6".

    When ``with_nulls`` is truthy the entry at index 4 is replaced by None.
    """
    values = [str(number) for number in range(1, 7)]
    if with_nulls:
        values[4] = None
    return pl.LazyFrame({"a": values})


@pytest.fixture
def str_to_float_data(with_nulls):
    """LazyFrame with a string column ``a`` of float literals.

    The column holds plain decimals, infinity and NaN spellings, a negative
    value and an exponent form. When ``with_nulls`` is truthy the entry at
    index 4 is replaced by None.
    """
    plain_decimals = ["1.1", "2.2", "3.3", "4.4", "5.5", "6.6"]
    non_finite = ["inf", "+inf", "-inf", "Inf", "-Inf", "nan"]
    negative_and_exponent = ["-1.234", "2e2"]
    values = [*plain_decimals, *non_finite, *negative_and_exponent]
    if with_nulls:
        values[4] = None
    return pl.LazyFrame({"a": values})


@pytest.fixture
def str_from_integer_data(with_nulls, integer_type):
    """LazyFrame with column ``a`` holding 1 to 6 typed as ``integer_type``.

    When ``with_nulls`` is truthy the entry at index 4 is replaced by None.
    """
    values = list(range(1, 7))
    if with_nulls:
        values[4] = None
    series = pl.Series(values, dtype=integer_type)
    return pl.LazyFrame({"a": series})


@pytest.fixture
def str_from_float_data(with_nulls, floating_type):
    """LazyFrame with column ``a`` of floats typed as ``floating_type``.

    The column holds six finite values followed by infinities and NaN.
    When ``with_nulls`` is truthy the entry at index 4 is replaced by None.
    """
    finite = [1.1, 2.2, 3.3, 4.4, 5.5, 6.6]
    non_finite = [float(token) for token in ("inf", "+inf", "-inf", "nan")]
    values = finite + non_finite
    if with_nulls:
        values[4] = None
    series = pl.Series(values, dtype=floating_type)
    return pl.LazyFrame({"a": series})


slice_cases = [
(1, 3),
(0, 3),
Expand Down Expand Up @@ -337,3 +410,47 @@ def test_unsupported_regex_raises(pattern):

q = df.select(pl.col("a").str.contains(pattern, strict=True))
assert_ir_translation_raises(q, NotImplementedError)


def test_string_to_integer(str_to_integer_data, integer_type):
    """Casting column ``a`` from strings to ``integer_type`` gives equal GPU results."""
    as_integer = pl.col("a").cast(integer_type)
    assert_gpu_result_equal(str_to_integer_data.select(as_integer))


def test_string_from_integer(str_from_integer_data):
    """Casting integer column ``a`` to ``pl.String`` gives equal GPU results."""
    as_string = pl.col("a").cast(pl.String)
    assert_gpu_result_equal(str_from_integer_data.select(as_string))


def test_string_to_float(str_to_float_data, floating_type):
    """Casting column ``a`` from strings to ``floating_type`` gives equal GPU results."""
    as_float = pl.col("a").cast(floating_type)
    assert_gpu_result_equal(str_to_float_data.select(as_float))


def test_string_from_float(request, str_from_float_data):
    """Cast float column ``a`` to strings and compare results case-insensitively."""
    is_float32 = str_from_float_data.collect_schema()["a"] == pl.Float32
    if is_float32:
        # libcudf formats floats to a hardcoded number of decimal places and
        # drops the rest of the fractional part, which causes mismatches for
        # some values. The float32 representation of 1.1 is
        # 1.100000023841858, which libcudf renders as "1.100000024", whereas
        # the float64 representation 1.1000000000000000888 is truncated to
        # "1.1". Float32 is therefore expected to fail.
        request.applymarker(pytest.mark.xfail(reason="libcudf truncation"))

    as_string = str_from_float_data.select(pl.col("a").cast(pl.String))

    # libcudf renders float('inf') as "inf" while polars renders it as "Inf",
    # so lowercase the result before comparing.
    lowered = as_string.select(pl.col("a").str.to_lowercase())
    assert_gpu_result_equal(lowered)


def test_string_to_numeric_invalid(numeric_type):
    """Casting non-numeric strings to ``numeric_type`` must raise on collect.

    The expected exception types are ``InvalidOperationError`` for polars
    and ``ComputeError`` for cudf.
    """
    non_numeric = pl.LazyFrame({"a": ["a", "b", "c"]})
    bad_cast = non_numeric.select(pl.col("a").cast(numeric_type))
    assert_collect_raises(
        bad_cast,
        polars_except=pl.exceptions.InvalidOperationError,
        cudf_except=pl.exceptions.ComputeError,
    )
Loading

0 comments on commit e4c52dd

Please sign in to comment.