Skip to content

Commit

Permalink
feat: Disallow order-dependent expressions from being passed to nw.La…
Browse files Browse the repository at this point in the history
…zyFrame (#1806)
  • Loading branch information
MarcoGorelli authored Jan 19, 2025
1 parent c94476c commit bb8d80a
Show file tree
Hide file tree
Showing 39 changed files with 667 additions and 340 deletions.
10 changes: 5 additions & 5 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ def shift(self, n: int) -> Self:
)

def cum_sum(self: Self, *, reverse: bool) -> Self:
if reverse:
if reverse: # pragma: no cover
msg = "`cum_sum(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)

Expand All @@ -371,7 +371,7 @@ def cum_sum(self: Self, *, reverse: bool) -> Self:
)

def cum_count(self: Self, *, reverse: bool) -> Self:
if reverse:
if reverse: # pragma: no cover
msg = "`cum_count(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)

Expand All @@ -382,7 +382,7 @@ def cum_count(self: Self, *, reverse: bool) -> Self:
)

def cum_min(self: Self, *, reverse: bool) -> Self:
if reverse:
if reverse: # pragma: no cover
msg = "`cum_min(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)

Expand All @@ -393,7 +393,7 @@ def cum_min(self: Self, *, reverse: bool) -> Self:
)

def cum_max(self: Self, *, reverse: bool) -> Self:
if reverse:
if reverse: # pragma: no cover
msg = "`cum_max(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)

Expand All @@ -404,7 +404,7 @@ def cum_max(self: Self, *, reverse: bool) -> Self:
)

def cum_prod(self: Self, *, reverse: bool) -> Self:
if reverse:
if reverse: # pragma: no cover
msg = "`cum_prod(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)

Expand Down
4 changes: 2 additions & 2 deletions narwhals/_dask/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def var(

try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError:
except ModuleNotFoundError: # pragma: no cover
import dask_expr as dx

return partial(dx._groupby.GroupBy.var, ddof=ddof)
Expand All @@ -66,7 +66,7 @@ def std(

try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError:
except ModuleNotFoundError: # pragma: no cover
import dask_expr as dx

return partial(dx._groupby.GroupBy.std, ddof=ddof)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def add_row_index(
def validate_comparand(lhs: dx.Series, rhs: dx.Series) -> None:
try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError:
except ModuleNotFoundError: # pragma: no cover
import dask_expr as dx

if not dx._expr.are_co_aligned(lhs._expr, rhs._expr): # pragma: no cover
Expand Down
8 changes: 8 additions & 0 deletions narwhals/_expression_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from narwhals.typing import CompliantNamespace
from narwhals.typing import CompliantSeries
from narwhals.typing import CompliantSeriesT_co
from narwhals.typing import IntoExpr

IntoCompliantExpr: TypeAlias = (
CompliantExpr[CompliantSeriesT_co] | str | CompliantSeriesT_co
Expand Down Expand Up @@ -334,3 +335,10 @@ def extract_compliant(
if isinstance(other, Series):
return other._compliant_series
return other


def operation_is_order_dependent(*args: IntoExpr | Any) -> bool:
# If an arg is an Expr, we look at `_is_order_dependent`. If it isn't,
# it means that it was a scalar (e.g. nw.col('a') + 1) or a column name,
# neither of which is order-dependent, so we default to `False`.
return any(getattr(x, "_is_order_dependent", False) for x in args)
76 changes: 58 additions & 18 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from abc import abstractmethod
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
Expand All @@ -15,6 +16,7 @@

from narwhals.dependencies import get_polars
from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import OrderDependentExprError
from narwhals.schema import Schema
from narwhals.translate import to_native
from narwhals.utils import find_stacklevel
Expand Down Expand Up @@ -70,25 +72,9 @@ def _flatten_and_extract(self, *args: Any, **kwargs: Any) -> Any:
kwargs = {k: self._extract_compliant(v) for k, v in kwargs.items()}
return args, kwargs

@abstractmethod
def _extract_compliant(self, arg: Any) -> Any:
from narwhals.expr import Expr
from narwhals.series import Series

if isinstance(arg, BaseFrame):
return arg._compliant_frame
if isinstance(arg, Series):
return arg._compliant_series
if isinstance(arg, Expr):
return arg._to_compliant_expr(self.__narwhals_namespace__())
if get_polars() is not None and "polars" in str(type(arg)):
msg = (
f"Expected Narwhals object, got: {type(arg)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.col` instead of `nw.col`?"
)
raise TypeError(msg)
return arg
raise NotImplementedError

@property
def schema(self) -> Schema:
Expand Down Expand Up @@ -361,6 +347,26 @@ class DataFrame(BaseFrame[DataFrameT]):
```
"""

def _extract_compliant(self, arg: Any) -> Any:
from narwhals.expr import Expr
from narwhals.series import Series

if isinstance(arg, BaseFrame):
return arg._compliant_frame
if isinstance(arg, Series):
return arg._compliant_series
if isinstance(arg, Expr):
return arg._to_compliant_expr(self.__narwhals_namespace__())
if get_polars() is not None and "polars" in str(type(arg)):
msg = (
f"Expected Narwhals object, got: {type(arg)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.col` instead of `nw.col`?"
)
raise TypeError(msg)
return arg

@property
def _series(self) -> type[Series[Any]]:
from narwhals.series import Series
Expand Down Expand Up @@ -3621,6 +3627,40 @@ class LazyFrame(BaseFrame[FrameT]):
```
"""

def _extract_compliant(self, arg: Any) -> Any:
from narwhals.expr import Expr
from narwhals.series import Series

if isinstance(arg, BaseFrame):
return arg._compliant_frame
if isinstance(arg, Series): # pragma: no cover
msg = "Binary operations between Series and LazyFrame are not supported."
raise TypeError(msg)
if isinstance(arg, Expr):
if arg._is_order_dependent:
msg = (
"Order-dependent expressions are not supported for use in LazyFrame.\n\n"
"Hints:\n"
"- Instead of `lf.select(nw.col('a').sort())`, use `lf.select('a').sort()\n"
"- Instead of `lf.select(nw.col('a').head())`, use `lf.select('a').head()\n"
"- `Expr.cum_sum`, and other such expressions, are not currently supported.\n"
" In a future version of Narwhals, a `order_by` argument will be added and \n"
" they will be supported."
)
raise OrderDependentExprError(msg)
return arg._to_compliant_expr(self.__narwhals_namespace__())
if get_polars() is not None and "polars" in str(type(arg)): # pragma: no cover
msg = (
f"Expected Narwhals object, got: {type(arg)}.\n\n"
"Perhaps you:\n"
"- Forgot a `nw.from_native` somewhere?\n"
"- Used `pl.col` instead of `nw.col`?"
)
raise TypeError(msg)
# TODO(unassigned): should this line even be reachable? Should we
# be raising here?
return arg # pragma: no cover

@property
def _dataframe(self) -> type[DataFrame[Any]]:
return DataFrame
Expand Down
2 changes: 1 addition & 1 deletion narwhals/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def get_ibis() -> Any:
return sys.modules.get("ibis", None)


def get_dask_expr() -> Any:
def get_dask_expr() -> Any: # pragma: no cover
"""Get dask_expr module (if already imported - else return None)."""
return sys.modules.get("dask_expr", None)

Expand Down
8 changes: 8 additions & 0 deletions narwhals/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ def from_expr_name(cls, expr_name: str) -> AnonymousExprError:
return AnonymousExprError(message)


class OrderDependentExprError(ValueError):
"""Exception raised when trying to use an order-dependent expressions with LazyFrames."""

def __init__(self, message: str) -> None:
self.message = message
super().__init__(self.message)


class UnsupportedDTypeError(ValueError):
"""Exception raised when trying to convert to a DType which is not supported by the given backend."""

Expand Down
Loading

0 comments on commit bb8d80a

Please sign in to comment.