Skip to content

Commit

Permalink
Fix dask token normalization (#14829)
Browse files Browse the repository at this point in the history
This PR fixes cudf's `__dask_tokenization__` definitions so that they will produce data that can be deterministically tokenized when a `MultiIndex` is present. I ran into this problem in dask-expr for an index with datetime data (a case reflected by the new test).

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Charles Blackmon-Luca (https://github.com/charlesbluca)

URL: #14829
  • Loading branch information
rjzamora authored Jan 31, 2024
1 parent 6ed75ff commit bb59715
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 6 deletions.
6 changes: 4 additions & 2 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1954,10 +1954,12 @@ def _repeat(
@_cudf_nvtx_annotate
@_warn_no_dask_cudf
def __dask_tokenize__(self):
from dask.base import normalize_token

return [
type(self),
self._dtypes,
self.to_pandas(),
normalize_token(self._dtypes),
normalize_token(self.to_pandas()),
]


Expand Down
8 changes: 5 additions & 3 deletions python/cudf/cudf/core/indexed_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -6176,11 +6176,13 @@ def convert_dtypes(

@_warn_no_dask_cudf
def __dask_tokenize__(self):
from dask.base import normalize_token

return [
type(self),
self._dtypes,
self.index,
self.hash_values().values_host,
normalize_token(self._dtypes),
normalize_token(self.index),
normalize_token(self.hash_values().values_host),
]


Expand Down
14 changes: 13 additions & 1 deletion python/dask_cudf/dask_cudf/tests/test_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION.
# Copyright (c) 2021-2024, NVIDIA CORPORATION.

from datetime import datetime

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -82,6 +84,16 @@ def test_deterministic_tokenize(index):
assert tokenize(df2) == tokenize(df2)


def test_deterministic_tokenize_multiindex():
dt = datetime.strptime("1995-03-15", "%Y-%m-%d")
index = cudf.MultiIndex(
levels=[[1, 2], [dt]],
codes=[[0, 1], [0, 0]],
)
df = cudf.DataFrame(index=index)
assert tokenize(df) == tokenize(df)


@pytest.mark.parametrize("preserve_index", [True, False])
def test_pyarrow_schema_dispatch(preserve_index):
from dask.dataframe.dispatch import (
Expand Down

0 comments on commit bb59715

Please sign in to comment.