Skip to content

Commit

Permalink
Add support for make_meta_obj dispatch in dask-cudf (#8342)
Browse files Browse the repository at this point in the history
Fixes: #7946 

This PR depends on upstream dask changes that provide part of the fix: https://github.com/dask/dask/pull/7586/files

This PR introduces a `make_meta_obj` registration, which ensures that the proper metadata is retrieved from the `parent_meta` that the upstream PR passes in.

Authors:
  - GALI PREM SAGAR (https://github.com/galipremsagar)

Approvers:
  - jakirkham (https://github.com/jakirkham)
  - Keith Kraus (https://github.com/kkraus14)

URL: #8342
  • Loading branch information
galipremsagar authored May 26, 2021
1 parent 2383193 commit cbbcba7
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 12 deletions.
10 changes: 6 additions & 4 deletions python/dask_cudf/dask_cudf/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pyarrow as pa

from dask.dataframe.categorical import categorical_dtype_dispatch
from dask.dataframe.core import get_parallel_type, make_meta, meta_nonempty
from dask.dataframe.core import get_parallel_type, meta_nonempty
from dask.dataframe.methods import (
concat_dispatch,
is_categorical_dtype_dispatch,
Expand All @@ -18,6 +18,8 @@
_scalar_from_dtype,
is_arraylike,
is_scalar,
make_meta,
make_meta_obj,
)

import cudf
Expand Down Expand Up @@ -133,8 +135,8 @@ def _empty_series(name, dtype, index=None):
return cudf.Series([], dtype=dtype, name=name, index=index)


@make_meta.register(object)
def make_meta_object(x, index=None):
@make_meta_obj.register(object)
def make_meta_object_cudf(x, index=None):
"""Create an empty cudf object containing the desired metadata.
Parameters
Expand Down Expand Up @@ -244,7 +246,7 @@ def is_categorical_dtype_cudf(obj):

try:

from dask.dataframe.utils import group_split_dispatch, hash_object_dispatch
from dask.dataframe.core import group_split_dispatch, hash_object_dispatch

def safe_hash(frame):
index = frame.index
Expand Down
6 changes: 3 additions & 3 deletions python/dask_cudf/dask_cudf/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(self, dsk, name, meta, divisions):
dsk = HighLevelGraph.from_collections(name, dsk, dependencies=[])
self.dask = dsk
self._name = name
meta = dd.core.make_meta(meta)
meta = dd.utils.make_meta_util(meta)
if not isinstance(meta, self._partition_type):
raise TypeError(
f"Expected meta to specify type "
Expand Down Expand Up @@ -115,7 +115,7 @@ def assigner(df, k, v):
out[k] = v
return out

meta = assigner(self._meta, k, dd.core.make_meta(v))
meta = assigner(self._meta, k, dd.utils.make_meta_util(v))
return self.map_partitions(assigner, k, v, meta=meta)

def apply_rows(self, func, incols, outcols, kwargs=None, cache_key=None):
Expand Down Expand Up @@ -677,7 +677,7 @@ def reduction(
if meta is None:
meta_chunk = _emulate(apply, chunk, args, chunk_kwargs)
meta = _emulate(apply, aggregate, [[meta_chunk]], aggregate_kwargs)
meta = dd.core.make_meta(meta)
meta = dd.utils.make_meta_util(meta)

graph = HighLevelGraph.from_collections(b, dsk, dependencies=args)
return dd.core.new_dd_object(graph, b, meta, (None, None))
Expand Down
28 changes: 23 additions & 5 deletions python/dask_cudf/dask_cudf/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

import dask
from dask import dataframe as dd
from dask.dataframe.core import make_meta, meta_nonempty
from dask.dataframe.core import meta_nonempty
from dask.dataframe.utils import make_meta_util
from dask.utils import M

import cudf
Expand Down Expand Up @@ -585,20 +586,20 @@ def test_hash_object_dispatch(index):
)

# DataFrame
result = dd.utils.hash_object_dispatch(obj, index=index)
result = dd.core.hash_object_dispatch(obj, index=index)
expected = dgd.backends.hash_object_cudf(obj, index=index)
assert isinstance(result, cudf.Series)
dd.assert_eq(result, expected)

# Series
result = dd.utils.hash_object_dispatch(obj["x"], index=index)
result = dd.core.hash_object_dispatch(obj["x"], index=index)
expected = dgd.backends.hash_object_cudf(obj["x"], index=index)
assert isinstance(result, cudf.Series)
dd.assert_eq(result, expected)

# DataFrame with MultiIndex
obj_multi = obj.set_index(["x", "z"], drop=True)
result = dd.utils.hash_object_dispatch(obj_multi, index=index)
result = dd.core.hash_object_dispatch(obj_multi, index=index)
expected = dgd.backends.hash_object_cudf(obj_multi, index=index)
assert isinstance(result, cudf.Series)
dd.assert_eq(result, expected)
Expand Down Expand Up @@ -638,7 +639,7 @@ def test_make_meta_backends(index):
df = df.set_index(index)

# Check "empty" metadata types
chk_meta = make_meta(df)
chk_meta = make_meta_util(df)
dd.assert_eq(chk_meta.dtypes, df.dtypes)

# Check "non-empty" metadata types
Expand Down Expand Up @@ -777,3 +778,20 @@ def test_index_map_partitions():
mins_gd = gddf.index.map_partitions(M.min, meta=gddf.index).compute()

dd.assert_eq(mins_pd, mins_gd)


def test_correct_meta():
    """Check that a pandas-backed dask collection keeps pandas metadata.

    With ``dask_cudf`` imported, a pandas-backed dask Series is expanded via
    ``apply(pd.Series, meta=<dict>)``. The test asserts that the result is a
    ``dd.DataFrame`` whose ``_meta`` is a ``pandas.DataFrame``.
    """
    # Need these local imports in this specific order.
    # For context: https://github.com/rapidsai/cudf/issues/7946
    import pandas as pd

    from dask import dataframe as dd

    # Imported only for its import-time side effects (hence the F401
    # suppression); presumably this registers the cudf dispatch
    # implementations -- confirm against dask_cudf.backends.
    import dask_cudf # noqa: F401

    # Small two-column pandas frame wrapped in a single-partition dask frame.
    df = pd.DataFrame({"a": [3, 4], "b": [1, 2]})
    ddf = dd.from_pandas(df, npartitions=1)
    # ``meta`` is passed as a plain dict of column name -> dtype string, so
    # dask has to build the metadata object itself.
    emb = ddf["a"].apply(pd.Series, meta={"c0": "int64", "c1": "int64"})

    # The collection and its metadata must both be pandas-backed.
    assert isinstance(emb, dd.DataFrame)
    assert isinstance(emb._meta, pd.DataFrame)

0 comments on commit cbbcba7

Please sign in to comment.