Skip to content

Commit

Permalink
Fix categorical-accessor support and testing in dask-cudf (#15591)
Browse files Browse the repository at this point in the history
Related to #15027

Adds a minor tokenization fix and adjusts testing for categorical-accessor support.

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
  - Charles Blackmon-Luca (https://github.com/charlesbluca)
  - Matthew Roeschke (https://github.com/mroeschke)
  - Bradley Dice (https://github.com/bdice)

URL: #15591
  • Loading branch information
rjzamora authored May 1, 2024
1 parent fe4b92c commit 67d427d
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 18 deletions.
7 changes: 6 additions & 1 deletion python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -6308,7 +6308,12 @@ def __dask_tokenize__(self):

         return [
             type(self),
-            normalize_token(self._dtypes),
+            str(self._dtypes),
+            *[
+                normalize_token(cat.categories)
+                for cat in self._dtypes.values()
+                if cat == "category"
+            ],
             normalize_token(self.index),
             normalize_token(self.hash_values().values_host),
         ]
Expand Down
4 changes: 2 additions & 2 deletions python/dask_cudf/dask_cudf/io/tests/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import dask_cudf
from dask_cudf.tests.utils import skip_dask_expr

-# No dask-expr support for dask_expr<1.0.6
-pytestmark = skip_dask_expr(lt_version="1.0.6")
+# No dask-expr support for dask<2024.4.0
+pytestmark = skip_dask_expr(lt_version="2024.4.0")


def test_read_json_backend_dispatch(tmp_path):
Expand Down
4 changes: 2 additions & 2 deletions python/dask_cudf/dask_cudf/io/tests/test_orc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import dask_cudf
from dask_cudf.tests.utils import skip_dask_expr

-# No dask-expr support for dask_expr<1.0.6
-pytestmark = skip_dask_expr(lt_version="1.0.6")
+# No dask-expr support for dask<2024.4.0
+pytestmark = skip_dask_expr(lt_version="2024.4.0")

cur_dir = os.path.dirname(__file__)
sample_orc = os.path.join(cur_dir, "data/orc/sample.orc")
Expand Down
2 changes: 1 addition & 1 deletion python/dask_cudf/dask_cudf/io/tests/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,7 @@ def test_check_file_size(tmpdir):
dask_cudf.io.read_parquet(fn, check_file_size=1).compute()


-@xfail_dask_expr("HivePartitioning cannot be hashed", lt_version="1.0")
+@xfail_dask_expr("HivePartitioning cannot be hashed", lt_version="2024.3.0")
def test_null_partition(tmpdir):
import pyarrow as pa
from pyarrow.dataset import HivePartitioning
Expand Down
4 changes: 2 additions & 2 deletions python/dask_cudf/dask_cudf/io/tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import dask_cudf
from dask_cudf.tests.utils import skip_dask_expr

-# No dask-expr support for dask_expr<1.0.6
-pytestmark = skip_dask_expr(lt_version="1.0.6")
+# No dask-expr support for dask<2024.4.0
+pytestmark = skip_dask_expr(lt_version="2024.4.0")

cur_dir = os.path.dirname(__file__)
text_file = os.path.join(cur_dir, "data/text/sample.pgn")
Expand Down
18 changes: 14 additions & 4 deletions python/dask_cudf/dask_cudf/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_categorical_accessor_initialization2(data):
dsr.cat


-@xfail_dask_expr("TODO: Unexplained dask-expr failure")
+@xfail_dask_expr(lt_version="2024.5.0")
@pytest.mark.parametrize("data", [data_cat_1()])
def test_categorical_basic(data):
cat = data.copy()
Expand Down Expand Up @@ -203,7 +203,6 @@ def test_categorical_compare_unordered(data):
dsr < dsr


-@xfail_dask_expr("TODO: Unexplained dask-expr failure")
@pytest.mark.parametrize("data", [data_cat_3()])
def test_categorical_compare_ordered(data):
cat1 = data[0].copy()
Expand Down Expand Up @@ -274,7 +273,6 @@ def test_categorical_categories():
)


-@xfail_dask_expr("TODO: Unexplained dask-expr failure")
def test_categorical_as_known():
df = dask_cudf.from_cudf(DataFrame({"col_1": [0, 1, 2, 3]}), npartitions=2)
df["col_1"] = df["col_1"].astype("category")
Expand All @@ -283,7 +281,19 @@ def test_categorical_as_known():
pdf = dd.from_pandas(pd.DataFrame({"col_1": [0, 1, 2, 3]}), npartitions=2)
pdf["col_1"] = pdf["col_1"].astype("category")
expected = pdf["col_1"].cat.as_known()
-    dd.assert_eq(expected, actual)
+
+    # Note: Categories may be ordered differently in
+    # cudf and pandas. Therefore, we need to compare
+    # the global set of categories (before and after
+    # calling `compute`), then we need to check that
+    # the initial order of rows was preserved.
+    assert set(expected.cat.categories) == set(
+        actual.cat.categories.values_host
+    )
+    assert set(expected.compute().cat.categories) == set(
+        actual.compute().cat.categories.values_host
+    )
+    dd.assert_eq(expected, actual.astype(expected.dtype))


def test_str_slice():
Expand Down
11 changes: 5 additions & 6 deletions python/dask_cudf/dask_cudf/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,17 @@
 import pytest
 from packaging.version import Version

+import dask
 import dask.dataframe as dd

 import cudf

 from dask_cudf.expr import QUERY_PLANNING_ON

 if QUERY_PLANNING_ON:
-    import dask_expr
-
-    DASK_EXPR_VERSION = Version(dask_expr.__version__)
+    DASK_VERSION = Version(dask.__version__)
 else:
-    DASK_EXPR_VERSION = None
+    DASK_VERSION = None


def _make_random_frame(nelem, npartitions=2, include_na=False):
Expand All @@ -37,15 +36,15 @@ def _make_random_frame(nelem, npartitions=2, include_na=False):

 def skip_dask_expr(reason=_default_reason, lt_version=None):
     if lt_version is not None:
-        skip = QUERY_PLANNING_ON and DASK_EXPR_VERSION < Version(lt_version)
+        skip = QUERY_PLANNING_ON and DASK_VERSION < Version(lt_version)
     else:
         skip = QUERY_PLANNING_ON
     return pytest.mark.skipif(skip, reason=reason)


 def xfail_dask_expr(reason=_default_reason, lt_version=None):
     if lt_version is not None:
-        xfail = QUERY_PLANNING_ON and DASK_EXPR_VERSION < Version(lt_version)
+        xfail = QUERY_PLANNING_ON and DASK_VERSION < Version(lt_version)
     else:
         xfail = QUERY_PLANNING_ON
     return pytest.mark.xfail(xfail, reason=reason)

0 comments on commit 67d427d

Please sign in to comment.