diff --git a/dask_expr/__init__.py b/dask_expr/__init__.py index 55da24c2..0eebc0d1 100644 --- a/dask_expr/__init__.py +++ b/dask_expr/__init__.py @@ -1,5 +1,6 @@ from dask_expr import _version, datasets from dask_expr._collection import * +from dask_expr._dispatch import get_collection_type from dask_expr._dummies import get_dummies from dask_expr.io._delayed import from_delayed from dask_expr.io.bag import to_bag diff --git a/dask_expr/_backends.py b/dask_expr/_backends.py new file mode 100644 index 00000000..2a9b93e0 --- /dev/null +++ b/dask_expr/_backends.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import pandas as pd +from dask.backends import CreationDispatch +from dask.dataframe.backends import DataFrameBackendEntrypoint + +from dask_expr._dispatch import get_collection_type + +dataframe_creation_dispatch = CreationDispatch( + module_name="dataframe", + default="pandas", + entrypoint_root="dask_expr", + entrypoint_class=DataFrameBackendEntrypoint, + name="dataframe_creation_dispatch", +) + + +class PandasBackendEntrypoint(DataFrameBackendEntrypoint): + """Pandas-Backend Entrypoint Class for Dask-Expressions + + Note that all DataFrame-creation functions are defined + and registered 'in-place'. + """ + + @classmethod + def to_backend_dispatch(cls): + from dask.dataframe.dispatch import to_pandas_dispatch + + return to_pandas_dispatch + + @classmethod + def to_backend(cls, data, **kwargs): + if isinstance(data._meta, (pd.DataFrame, pd.Series, pd.Index)): + # Already a pandas-backed collection + return data + return data.map_partitions(cls.to_backend_dispatch(), **kwargs) + + +dataframe_creation_dispatch.register_backend("pandas", PandasBackendEntrypoint()) + + +@get_collection_type.register(pd.Series) +def get_collection_type_series(_): + from dask_expr._collection import Series + + return Series + + +@get_collection_type.register(pd.DataFrame) +def get_collection_type_dataframe(_): + from dask_expr._collection import DataFrame + + return DataFrame + + +@get_collection_type.register(pd.Index) +def get_collection_type_index(_): + from dask_expr._collection import Index + + return Index + + +@get_collection_type.register(object) +def get_collection_type_object(_): + from dask_expr._collection import Scalar + + return Scalar + + +###################################### +# cuDF: Pandas Dataframes on the GPU # +###################################### + + +@get_collection_type.register_lazy("cudf") +def _register_cudf(): + import dask_cudf # noqa: F401 diff --git a/dask_expr/_collection.py b/dask_expr/_collection.py index b2099470..ac328348 100644 --- a/dask_expr/_collection.py +++ b/dask_expr/_collection.py @@ -26,7 +26,6 @@ has_parallel_type, is_arraylike, is_dataframe_like, - is_index_like, is_series_like, meta_warning, new_dd_object, @@ -57,12 +56,15 @@ from pandas.api.types import is_timedelta64_dtype from tlz import first +import dask_expr._backends # noqa: F401 from dask_expr import _expr as expr from dask_expr._align import AlignPartitions +from dask_expr._backends import dataframe_creation_dispatch from dask_expr._categorical import CategoricalAccessor, Categorize, GetCategories from dask_expr._concat import Concat from dask_expr._datetime import DatetimeAccessor from dask_expr._describe import DescribeNonNumeric, DescribeNumeric +from dask_expr._dispatch import get_collection_type from dask_expr._expr import ( BFill, Diff, @@ -1578,9 +1580,14 @@ def to_backend(self, backend: str | None = None, **kwargs): ------- DataFrame, Series or Index """ - from dask.dataframe.io import to_backend + from dask_expr._backends import dataframe_creation_dispatch - return to_backend(self.to_dask_dataframe(), backend=backend, **kwargs) + # Get desired backend + backend = backend or dataframe_creation_dispatch.backend + # Check that "backend" has a registered entrypoint + backend_entrypoint = dataframe_creation_dispatch.dispatch(backend) + # Call `DataFrameBackendEntrypoint.to_backend` + return backend_entrypoint.to_backend(self, **kwargs) def dot(self, other, meta=no_default): if not isinstance(other, FrameBase): @@ -1981,7 +1988,7 @@ def __getattr__(self, key): # Check if key is in columns if key # is not a normal attribute if key in self.expr._meta.columns: - return Series(self.expr[key]) + return new_collection(self.expr[key]) raise err except AttributeError: # Fall back to `BaseFrame.__getattr__` @@ -3159,17 +3166,9 @@ def __array__(self): def new_collection(expr): """Create new collection from an expr""" - meta = expr._meta expr._name # Ensure backend is imported - if is_dataframe_like(meta): - return DataFrame(expr) - elif is_series_like(meta): - return Series(expr) - elif is_index_like(meta): - return Index(expr) - else: - return Scalar(expr) + return get_collection_type(meta)(expr) def optimize(collection, fuse=True): @@ -3236,6 +3235,7 @@ def from_graph(*args, **kwargs): return new_collection(FromGraph(*args, **kwargs)) +@dataframe_creation_dispatch.register_inplace("pandas") def from_dict( data, npartitions, diff --git a/dask_expr/_dispatch.py b/dask_expr/_dispatch.py new file mode 100644 index 00000000..dee5178e --- /dev/null +++ b/dask_expr/_dispatch.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from dask.utils import Dispatch + +get_collection_type = Dispatch("get_collection_type")