feat: add more Spark Expressions #1724

Open · wants to merge 11 commits into base: main
67 changes: 67 additions & 0 deletions narwhals/_spark_like/expr.py
@@ -1,14 +1,21 @@
from __future__ import annotations

from copy import copy
from itertools import chain
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Mapping
from typing import Sequence

from narwhals._spark_like.utils import get_column_name
from narwhals._spark_like.utils import maybe_evaluate
from narwhals.exceptions import InvalidOperationError
from narwhals.typing import CompliantExpr

if TYPE_CHECKING:
    from narwhals.dtypes import DType

from narwhals.utils import Implementation
from narwhals.utils import parse_version

@@ -201,6 +208,22 @@ def _alias(df: SparkLikeLazyFrame) -> list[Column]:
            kwargs={**self._kwargs, "name": name},
        )

    def all(self) -> Self:
        def _all(_input: Column) -> Column:
            from pyspark.sql import functions as F  # noqa: N812

            return F.bool_and(_input)
EdAbati marked this conversation as resolved.

        return self._from_call(_all, "all", returns_scalar=True)
Comment on lines +212 to +217 (Collaborator):

We simplified the other methods a bit; we can refactor this as:

Suggested change (replace the nested `_all` helper with a direct call):

        from pyspark.sql import functions as F  # noqa: N812

        return self._from_call(F.bool_and, "all", returns_scalar=True)


    def any(self) -> Self:
        def _any(_input: Column) -> Column:
            from pyspark.sql import functions as F  # noqa: N812

            return F.bool_or(_input)

        return self._from_call(_any, "any", returns_scalar=True)
Comment on lines +220 to +225 (Collaborator):

Suggested change (same as above, replace the nested `_any` helper with a direct call):

        from pyspark.sql import functions as F  # noqa: N812

        return self._from_call(F.bool_or, "any", returns_scalar=True)


    def count(self) -> Self:
        def _count(_input: Column) -> Column:
            from pyspark.sql import functions as F  # noqa: N812
@@ -233,6 +256,50 @@ def _min(_input: Column) -> Column:

        return self._from_call(_min, "min", returns_scalar=True)

    def null_count(self) -> Self:
        def _null_count(_input: Column) -> Column:
            from pyspark.sql import functions as F  # noqa: N812

            return F.count_if(F.isnull(_input))

        return self._from_call(_null_count, "null_count", returns_scalar=True)

    def replace_strict(
Collaborator:

I am tempted to say that this should not be implemented for now and should just raise a NotImplementedError (as we do in Dask).
We would need to be able to access the dataframe (and collect the results) to get the distinct values of the column.

@FBruzzesi and @MarcoGorelli any thoughts?

Member:

> I am tempted to say that this should not be implemented for now and should just raise a NotImplementedError (as we do in Dask).

Sure, we can evaluate if and how to support replace_strict later on. Super excited to ship the rest for now πŸ™ŒπŸΌ

Member:

agree!

        self,
        old: Sequence[Any] | Mapping[Any, Any],
        new: Sequence[Any] | None = None,
        default: Any | None = None,
        return_dtype: DType | None = None,
    ) -> Self:
        def _replace_strict(_input: Column) -> Column:
            from pyspark.sql import functions as F  # noqa: N812

            if isinstance(old, Mapping):
                mapping = old
            else:
                if new is None:
                    msg = "`new` argument is required if `old` argument is not a Mapping type"
                    raise InvalidOperationError(msg)
                mapping = dict(zip(old, new))  # QUESTION: check len(old) == len(new)?

            mapping_expr = F.create_map([F.lit(x) for x in chain(*mapping.items())])
            replacements = mapping_expr[_input]

            if default:
                replacements = F.coalesce(replacements, F.lit(default))

            # QUESTION: check all values mapped?
            # we can check that all values are mapped using: F.bool_and(replacements.isNotNull())
            # however, I'm not sure how to validate this as an expression - F.assert_true looked promising
            # until I realized it will convert the expression to NULL if the condition is True

            if return_dtype:
                replacements = replacements.cast(return_dtype)

            return replacements

        return self._from_call(_replace_strict, "replace_strict", returns_scalar=False)
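
One possible way to address the QUESTION comments above, sketched here rather than asserted: keep the NULL-based lookup and fail at evaluation time with F.when plus F.raise_error (available in PySpark 3.1+). The helper name _check_all_mapped and the error message are illustrative and not part of this diff; the check would be applied to the raw lookup, before any default is coalesced in.

from pyspark.sql import Column
from pyspark.sql import functions as F  # noqa: N812


def _check_all_mapped(_input: Column, replacements: Column) -> Column:
    # Raise at evaluation time for any non-null input value whose lookup in the
    # mapping is still NULL (i.e. the value was not mapped).
    unmapped = _input.isNotNull() & replacements.isNull()
    return F.when(
        unmapped,
        F.raise_error(
            F.concat(F.lit("replace_strict: value not mapped: "), _input.cast("string"))
        ),
    ).otherwise(replacements)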

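For reference, the alternative raised in the review thread above, mirroring the Dask backend, would look roughly like the sketch below; the error message wording is an assumption, and the signature simply repeats the one in this diff.

    def replace_strict(
        self,
        old: Sequence[Any] | Mapping[Any, Any],
        new: Sequence[Any] | None = None,
        default: Any | None = None,
        return_dtype: DType | None = None,
    ) -> Self:
        # Sketch only: refuse the operation for now instead of implementing it.
        msg = "`replace_strict` is not yet supported for the PySpark backend"
        raise NotImplementedError(msg)
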
    def sum(self) -> Self:
        def _sum(_input: Column) -> Column:
            from pyspark.sql import functions as F  # noqa: N812
20 changes: 20 additions & 0 deletions narwhals/_spark_like/namespace.py
@@ -62,6 +62,26 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
            kwargs={"exprs": exprs},
        )

    def any_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr:
        parsed_exprs = parse_into_exprs(*exprs, namespace=self)

        def func(df: SparkLikeLazyFrame) -> list[Column]:
            cols = [c for _expr in parsed_exprs for c in _expr(df)]
            col_name = get_column_name(df, cols[0])
            return [reduce(operator.or_, cols).alias(col_name)]

        return SparkLikeExpr(  # type: ignore[abstract]
            call=func,
            depth=max(x._depth for x in parsed_exprs) + 1,
            function_name="any_horizontal",
            root_names=combine_root_names(parsed_exprs),
            output_names=reduce_output_names(parsed_exprs),
            returns_scalar=False,
            backend_version=self._backend_version,
            version=self._version,
            kwargs={"exprs": exprs},
        )

    def col(self, *column_names: str) -> SparkLikeExpr:
        return SparkLikeExpr.from_column_names(
            *column_names, backend_version=self._backend_version, version=self._version
136 changes: 136 additions & 0 deletions tests/spark_like_test.py
@@ -15,11 +15,13 @@

import narwhals.stable.v1 as nw
from narwhals.exceptions import ColumnNotFoundError
from tests.utils import POLARS_VERSION
from tests.utils import assert_equal_data

if TYPE_CHECKING:
    from pyspark.sql import SparkSession

    from narwhals.dtypes import DType
    from narwhals.typing import IntoFrame
    from tests.utils import Constructor

@@ -285,6 +287,35 @@ def test_allh_all(pyspark_constructor: Constructor) -> None:
    assert_equal_data(result, expected)


# copied from tests/expr_and_series/any_horizontal_test.py
@pytest.mark.parametrize("expr1", ["a", nw.col("a")])
@pytest.mark.parametrize("expr2", ["b", nw.col("b")])
def test_anyh(pyspark_constructor: Constructor, expr1: Any, expr2: Any) -> None:
    data = {
        "a": [False, False, True],
        "b": [False, True, True],
    }
    df = nw.from_native(pyspark_constructor(data))
    result = df.select(any=nw.any_horizontal(expr1, expr2))

    expected = {"any": [False, True, True]}
    assert_equal_data(result, expected)


def test_anyh_all(pyspark_constructor: Constructor) -> None:
    data = {
        "a": [False, False, True],
        "b": [False, True, True],
    }
    df = nw.from_native(pyspark_constructor(data))
    result = df.select(any=nw.any_horizontal(nw.all()))
    expected = {"any": [False, True, True]}
    assert_equal_data(result, expected)
    result = df.select(nw.any_horizontal(nw.all()))
    expected = {"a": [False, True, True]}
    assert_equal_data(result, expected)


# copied from tests/expr_and_series/sum_horizontal_test.py
@pytest.mark.parametrize("col_expr", [nw.col("a"), "a"])
def test_sumh(pyspark_constructor: Constructor, col_expr: Any) -> None:
@@ -324,6 +355,25 @@ def test_sumh_all(pyspark_constructor: Constructor) -> None:
    assert_equal_data(result, expected)


# copied from tests/expr_and_series/any_all_test.py
def test_any_all(pyspark_constructor: Constructor) -> None:
    df = nw.from_native(
        pyspark_constructor(
            {
                "a": [True, False, True],
                "b": [True, True, True],
                "c": [False, False, False],
            }
        )
    )
    result = df.select(nw.col("a", "b", "c").all())
    expected = {"a": [False], "b": [True], "c": [False]}
    assert_equal_data(result, expected)
    result = df.select(nw.all().any())
    expected = {"a": [True], "b": [True], "c": [False]}
    assert_equal_data(result, expected)


# copied from tests/expr_and_series/count_test.py
def test_count(pyspark_constructor: Constructor) -> None:
    data = {"a": [1, 2, 3], "b": [4, None, 6], "z": [7.0, None, None]}
@@ -374,6 +424,92 @@ def test_expr_min_expr(pyspark_constructor: Constructor) -> None:
    assert_equal_data(result, expected)


# copied from tests/expr_and_series/null_count_test.py
def test_null_count_expr(pyspark_constructor: Constructor) -> None:
    data = {
        "a": [1.0, None, None, 3.0],
        "b": [1.0, None, 4, 5.0],
    }
    df = nw.from_native(pyspark_constructor(data))
    result = df.select(nw.all().null_count())
    expected = {
        "a": [2],
        "b": [1],
    }
    assert_equal_data(result, expected)


@pytest.mark.skipif(
    POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0"
)
@pytest.mark.parametrize("return_dtype", [nw.String(), None])
def test_replace_strict(
    pyspark_constructor: Constructor,
    request: pytest.FixtureRequest,
    return_dtype: DType | None,
) -> None:
    if "dask" in str(pyspark_constructor):  # QUESTION: remove?
        request.applymarker(pytest.mark.xfail)
    df = nw.from_native(pyspark_constructor({"a": [1, 2, 3]}))
    result = df.select(
        nw.col("a").replace_strict(
            [1, 2, 3], ["one", "two", "three"], return_dtype=return_dtype
        )
    )
    assert_equal_data(result, {"a": ["one", "two", "three"]})


@pytest.mark.skipif(
    POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0"
)
def test_replace_non_full(
    pyspark_constructor: Constructor, request: pytest.FixtureRequest
) -> None:
    from polars.exceptions import PolarsError

    if "dask" in str(pyspark_constructor):  # QUESTION: remove?
        request.applymarker(pytest.mark.xfail)
    df = nw.from_native(pyspark_constructor({"a": [1, 2, 3]}))
    if isinstance(df, nw.LazyFrame):
        with pytest.raises((ValueError, PolarsError)):
            df.select(
                nw.col("a").replace_strict([1, 3], [3, 4], return_dtype=nw.Int64)
            ).collect()
    else:
        with pytest.raises((ValueError, PolarsError)):
            df.select(nw.col("a").replace_strict([1, 3], [3, 4], return_dtype=nw.Int64))


@pytest.mark.skipif(
    POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0"
)
def test_replace_strict_mapping(
    pyspark_constructor: Constructor, request: pytest.FixtureRequest
) -> None:
    if "dask" in str(pyspark_constructor):  # QUESTION: remove?
        request.applymarker(pytest.mark.xfail)

    df = nw.from_native(pyspark_constructor({"a": [1, 2, 3]}))
    result = df.select(
        nw.col("a").replace_strict(
            {1: "one", 2: "two", 3: "three"}, return_dtype=nw.String()
        )
    )
    assert_equal_data(result, {"a": ["one", "two", "three"]})


@pytest.mark.skipif(
    POLARS_VERSION < (1, 0), reason="replace_strict only available after 1.0"
)
def test_replace_strict_invalid(pyspark_constructor: Constructor) -> None:
    df = nw.from_native(pyspark_constructor({"a": [1, 2, 3]}))
    with pytest.raises(
        TypeError,
        match="`new` argument is required if `old` argument is not a Mapping type",
    ):
        df.select(nw.col("a").replace_strict(old=[1, 2, 3]))


# copied from tests/expr_and_series/min_test.py
@pytest.mark.parametrize("expr", [nw.col("a", "b", "z").sum(), nw.sum("a", "b", "z")])
def test_expr_sum_expr(pyspark_constructor: Constructor, expr: nw.Expr) -> None: