Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic backend dispatching to dask-expr #728

Merged
merged 18 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dask_expr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dask_expr import _version, datasets
from dask_expr._collection import *
from dask_expr._dispatch import get_collection_type
from dask_expr._dummies import get_dummies
from dask_expr.io._delayed import from_delayed
from dask_expr.io.bag import to_bag
Expand Down
77 changes: 77 additions & 0 deletions dask_expr/_backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from __future__ import annotations

import pandas as pd
from dask.backends import CreationDispatch
from dask.dataframe.backends import DataFrameBackendEntrypoint

from dask_expr._dispatch import get_collection_type

dataframe_creation_dispatch = CreationDispatch(
module_name="dataframe",
default="pandas",
entrypoint_root="dask_expr",
entrypoint_class=DataFrameBackendEntrypoint,
name="dataframe_creation_dispatch",
)


class PandasBackendEntrypoint(DataFrameBackendEntrypoint):
"""Pandas-Backend Entrypoint Class for Dask-Expressions

Note that all DataFrame-creation functions are defined
and registered 'in-place'.
"""

@classmethod
def to_backend_dispatch(cls):
from dask.dataframe.dispatch import to_pandas_dispatch

return to_pandas_dispatch

@classmethod
def to_backend(cls, data, **kwargs):
if isinstance(data._meta, (pd.DataFrame, pd.Series, pd.Index)):
# Already a pandas-backed collection
return data
return data.map_partitions(cls.to_backend_dispatch(), **kwargs)


dataframe_creation_dispatch.register_backend("pandas", PandasBackendEntrypoint())


@get_collection_type.register(pd.Series)
def get_collection_type_series(_):
from dask_expr._collection import Series

return Series


@get_collection_type.register(pd.DataFrame)
def get_collection_type_dataframe(_):
from dask_expr._collection import DataFrame

return DataFrame


@get_collection_type.register(pd.Index)
def get_collection_type_index(_):
from dask_expr._collection import Index

return Index


@get_collection_type.register(object)
def get_collection_type_object(_):
from dask_expr._collection import Scalar

return Scalar


######################################
# cuDF: Pandas Dataframes on the GPU #
######################################


@get_collection_type.register_lazy("cudf")
def _register_cudf():
import dask_cudf # noqa: F401
26 changes: 13 additions & 13 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
has_parallel_type,
is_arraylike,
is_dataframe_like,
is_index_like,
is_series_like,
meta_warning,
new_dd_object,
Expand Down Expand Up @@ -57,12 +56,15 @@
from pandas.api.types import is_timedelta64_dtype
from tlz import first

import dask_expr._backends # noqa: F401
from dask_expr import _expr as expr
from dask_expr._align import AlignPartitions
from dask_expr._backends import dataframe_creation_dispatch
from dask_expr._categorical import CategoricalAccessor, Categorize, GetCategories
from dask_expr._concat import Concat
from dask_expr._datetime import DatetimeAccessor
from dask_expr._describe import DescribeNonNumeric, DescribeNumeric
from dask_expr._dispatch import get_collection_type
from dask_expr._expr import (
BFill,
Diff,
Expand Down Expand Up @@ -1578,9 +1580,14 @@ def to_backend(self, backend: str | None = None, **kwargs):
-------
DataFrame, Series or Index
"""
from dask.dataframe.io import to_backend
from dask_expr._backends import dataframe_creation_dispatch

return to_backend(self.to_dask_dataframe(), backend=backend, **kwargs)
phofl marked this conversation as resolved.
Show resolved Hide resolved
# Get desired backend
backend = backend or dataframe_creation_dispatch.backend
# Check that "backend" has a registered entrypoint
backend_entrypoint = dataframe_creation_dispatch.dispatch(backend)
# Call `DataFrameBackendEntrypoint.to_backend`
return backend_entrypoint.to_backend(self, **kwargs)

def dot(self, other, meta=no_default):
if not isinstance(other, FrameBase):
Expand Down Expand Up @@ -1981,7 +1988,7 @@ def __getattr__(self, key):
# Check if key is in columns if key
# is not a normal attribute
if key in self.expr._meta.columns:
return Series(self.expr[key])
return new_collection(self.expr[key])
raise err
except AttributeError:
# Fall back to `BaseFrame.__getattr__`
Expand Down Expand Up @@ -3156,17 +3163,9 @@ def __array__(self):

def new_collection(expr):
"""Create new collection from an expr"""

meta = expr._meta
expr._name # Ensure backend is imported
if is_dataframe_like(meta):
return DataFrame(expr)
elif is_series_like(meta):
return Series(expr)
elif is_index_like(meta):
return Index(expr)
else:
return Scalar(expr)
return get_collection_type(meta)(expr)
Comment on lines 3164 to +3168
Copy link
Member Author

@rjzamora rjzamora Jan 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that we have been using new_collection from the beginning, because we knew we needed to make this change to support other backends (i.e. cudf) long term.

The behavior of new_collection is now consistent with the behavior of new_dd_object in dask.dataframe.



def optimize(collection, fuse=True):
Expand Down Expand Up @@ -3233,6 +3232,7 @@ def from_graph(*args, **kwargs):
return new_collection(FromGraph(*args, **kwargs))


@dataframe_creation_dispatch.register_inplace("pandas")
def from_dict(
data,
npartitions,
Expand Down
5 changes: 5 additions & 0 deletions dask_expr/_dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

from dask.utils import Dispatch

get_collection_type = Dispatch("get_collection_type")
Copy link
Member Author

@rjzamora rjzamora Jan 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we will need many dispatch functions here. So, we could also define this in dask.datafame.dispatch.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is now publicly exposed as dask_expr.get_collection_type. This can always be moved to dask.datafame.dispatch in the future.

Loading