Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Preserve null values in old Polars versions for ewm_mean #1574

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions narwhals/_polars/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from narwhals._polars.utils import extract_args_kwargs
from narwhals._polars.utils import extract_native
from narwhals._polars.utils import narwhals_to_native_dtype
from narwhals.dependencies import get_polars
from narwhals.utils import Implementation

if TYPE_CHECKING:
Expand Down Expand Up @@ -60,21 +61,23 @@ def ewm_mean(
min_periods: int,
ignore_nulls: bool,
) -> Self:
if self._backend_version < (1,): # pragma: no cover
msg = "`ewm_mean` not implemented for polars older than 1.0"
raise NotImplementedError(msg)
expr = self._native_expr
return self._from_native_expr(
expr.ewm_mean(
com=com,
span=span,
half_life=half_life,
alpha=alpha,
adjust=adjust,
min_periods=min_periods,
ignore_nulls=ignore_nulls,
)

native_expr = expr.ewm_mean(
com=com,
span=span,
half_life=half_life,
alpha=alpha,
adjust=adjust,
min_periods=min_periods,
ignore_nulls=ignore_nulls,
)
if self._backend_version < (1,): # pragma: no cover
pl = get_polars()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should be ok to just import polars here in _polars

return self._from_native_expr(
pl.when(expr.is_null()).then(None).otherwise(native_expr).name.keep()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does

pl.when(~expr.is_null()).then(native_expr).otherwise(None)

work? then it would mirror what you did for series

)
return self._from_native_expr(native_expr)

def map_batches(
self,
Expand Down
32 changes: 19 additions & 13 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from narwhals._polars.utils import extract_native
from narwhals._polars.utils import narwhals_to_native_dtype
from narwhals._polars.utils import native_to_narwhals_dtype
from narwhals.dependencies import get_polars
from narwhals.utils import Implementation

if TYPE_CHECKING:
Expand Down Expand Up @@ -262,21 +263,26 @@ def ewm_mean(
min_periods: int,
ignore_nulls: bool,
) -> Self:
native_series = self._native_series

native_result = native_series.ewm_mean(
com=com,
span=span,
half_life=half_life,
alpha=alpha,
adjust=adjust,
min_periods=min_periods,
ignore_nulls=ignore_nulls,
)
if self._backend_version < (1,): # pragma: no cover
msg = "`ewm_mean` not implemented for polars older than 1.0"
raise NotImplementedError(msg)
expr = self._native_series
return self._from_native_series(
expr.ewm_mean(
com=com,
span=span,
half_life=half_life,
alpha=alpha,
adjust=adjust,
min_periods=min_periods,
ignore_nulls=ignore_nulls,
pl = get_polars()
return self._from_native_series(
pl.select(
pl.when(~native_series.is_null()).then(native_result).otherwise(None)
)[native_series.name]
)
)

return self._from_native_series(native_result)

def sort(self: Self, *, descending: bool, nulls_last: bool) -> Self:
if self._backend_version < (0, 20, 6):
Expand Down
21 changes: 5 additions & 16 deletions tests/expr_and_series/ewm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pytest

import narwhals.stable.v1 as nw
from tests.utils import POLARS_VERSION
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data
Expand All @@ -16,9 +15,7 @@
"ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature."
)
def test_ewm_mean_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None:
if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or (
"polars" in str(constructor) and POLARS_VERSION < (1,)
):
if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand All @@ -36,9 +33,7 @@ def test_ewm_mean_expr(request: pytest.FixtureRequest, constructor: Constructor)
def test_ewm_mean_series(
request: pytest.FixtureRequest, constructor_eager: ConstructorEager
) -> None:
if any(x in str(constructor_eager) for x in ("pyarrow_table_", "modin")) or (
"polars" in str(constructor_eager) and POLARS_VERSION < (1,)
):
if any(x in str(constructor_eager) for x in ("pyarrow_table_", "modin")):
request.applymarker(pytest.mark.xfail)

series = nw.from_native(constructor_eager(data), eager_only=True)["a"]
Expand Down Expand Up @@ -75,9 +70,7 @@ def test_ewm_mean_expr_adjust(
adjust: bool, # noqa: FBT001
expected: dict[str, list[float]],
) -> None:
if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or (
"polars" in str(constructor) and POLARS_VERSION < (1,)
):
if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor(data))
Expand Down Expand Up @@ -137,9 +130,7 @@ def test_ewm_mean_nulls(
expected: dict[str, list[float]],
constructor: Constructor,
) -> None:
if any(
x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf")
) or ("polars" in str(constructor) and POLARS_VERSION < (1,)):
if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf")):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor({"a": [2.0, 4.0, None, 3.0]}))
Expand All @@ -154,9 +145,7 @@ def test_ewm_mean_params(
request: pytest.FixtureRequest,
constructor: Constructor,
) -> None:
if any(
x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf")
) or ("polars" in str(constructor) and POLARS_VERSION < (1,)):
if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf")):
request.applymarker(pytest.mark.xfail)

df = nw.from_native(constructor({"a": [2, 5, 3]}))
Expand Down
Loading