Skip to content

Commit

Permalink
Add basic backend dispatching to dask-expr (#728)
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora authored Jan 19, 2024
1 parent 6721f53 commit 13af21d
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 13 deletions.
1 change: 1 addition & 0 deletions dask_expr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dask_expr import _version, datasets
from dask_expr._collection import *
from dask_expr._dispatch import get_collection_type
from dask_expr._dummies import get_dummies
from dask_expr.io._delayed import from_delayed
from dask_expr.io.bag import to_bag
Expand Down
77 changes: 77 additions & 0 deletions dask_expr/_backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from __future__ import annotations

import pandas as pd
from dask.backends import CreationDispatch
from dask.dataframe.backends import DataFrameBackendEntrypoint

from dask_expr._dispatch import get_collection_type

dataframe_creation_dispatch = CreationDispatch(
module_name="dataframe",
default="pandas",
entrypoint_root="dask_expr",
entrypoint_class=DataFrameBackendEntrypoint,
name="dataframe_creation_dispatch",
)


class PandasBackendEntrypoint(DataFrameBackendEntrypoint):
"""Pandas-Backend Entrypoint Class for Dask-Expressions
Note that all DataFrame-creation functions are defined
and registered 'in-place'.
"""

@classmethod
def to_backend_dispatch(cls):
from dask.dataframe.dispatch import to_pandas_dispatch

return to_pandas_dispatch

@classmethod
def to_backend(cls, data, **kwargs):
if isinstance(data._meta, (pd.DataFrame, pd.Series, pd.Index)):
# Already a pandas-backed collection
return data
return data.map_partitions(cls.to_backend_dispatch(), **kwargs)


dataframe_creation_dispatch.register_backend("pandas", PandasBackendEntrypoint())


@get_collection_type.register(pd.Series)
def get_collection_type_series(_):
from dask_expr._collection import Series

return Series


@get_collection_type.register(pd.DataFrame)
def get_collection_type_dataframe(_):
from dask_expr._collection import DataFrame

return DataFrame


@get_collection_type.register(pd.Index)
def get_collection_type_index(_):
from dask_expr._collection import Index

return Index


@get_collection_type.register(object)
def get_collection_type_object(_):
from dask_expr._collection import Scalar

return Scalar


######################################
# cuDF: Pandas Dataframes on the GPU #
######################################


@get_collection_type.register_lazy("cudf")
def _register_cudf():
import dask_cudf # noqa: F401
26 changes: 13 additions & 13 deletions dask_expr/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
has_parallel_type,
is_arraylike,
is_dataframe_like,
is_index_like,
is_series_like,
meta_warning,
new_dd_object,
Expand Down Expand Up @@ -57,12 +56,15 @@
from pandas.api.types import is_timedelta64_dtype
from tlz import first

import dask_expr._backends # noqa: F401
from dask_expr import _expr as expr
from dask_expr._align import AlignPartitions
from dask_expr._backends import dataframe_creation_dispatch
from dask_expr._categorical import CategoricalAccessor, Categorize, GetCategories
from dask_expr._concat import Concat
from dask_expr._datetime import DatetimeAccessor
from dask_expr._describe import DescribeNonNumeric, DescribeNumeric
from dask_expr._dispatch import get_collection_type
from dask_expr._expr import (
BFill,
Diff,
Expand Down Expand Up @@ -1578,9 +1580,14 @@ def to_backend(self, backend: str | None = None, **kwargs):
-------
DataFrame, Series or Index
"""
from dask.dataframe.io import to_backend
from dask_expr._backends import dataframe_creation_dispatch

return to_backend(self.to_dask_dataframe(), backend=backend, **kwargs)
# Get desired backend
backend = backend or dataframe_creation_dispatch.backend
# Check that "backend" has a registered entrypoint
backend_entrypoint = dataframe_creation_dispatch.dispatch(backend)
# Call `DataFrameBackendEntrypoint.to_backend`
return backend_entrypoint.to_backend(self, **kwargs)

def dot(self, other, meta=no_default):
if not isinstance(other, FrameBase):
Expand Down Expand Up @@ -1981,7 +1988,7 @@ def __getattr__(self, key):
# Check if key is in columns if key
# is not a normal attribute
if key in self.expr._meta.columns:
return Series(self.expr[key])
return new_collection(self.expr[key])
raise err
except AttributeError:
# Fall back to `BaseFrame.__getattr__`
Expand Down Expand Up @@ -3159,17 +3166,9 @@ def __array__(self):

def new_collection(expr):
"""Create new collection from an expr"""

meta = expr._meta
expr._name # Ensure backend is imported
if is_dataframe_like(meta):
return DataFrame(expr)
elif is_series_like(meta):
return Series(expr)
elif is_index_like(meta):
return Index(expr)
else:
return Scalar(expr)
return get_collection_type(meta)(expr)


def optimize(collection, fuse=True):
Expand Down Expand Up @@ -3236,6 +3235,7 @@ def from_graph(*args, **kwargs):
return new_collection(FromGraph(*args, **kwargs))


@dataframe_creation_dispatch.register_inplace("pandas")
def from_dict(
data,
npartitions,
Expand Down
5 changes: 5 additions & 0 deletions dask_expr/_dispatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from __future__ import annotations

from dask.utils import Dispatch

get_collection_type = Dispatch("get_collection_type")

0 comments on commit 13af21d

Please sign in to comment.