From 68bc7207af34a8f4ba8fc2201c0950607dc94451 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dea=20Mar=C3=ADa=20L=C3=A9on?= Date: Thu, 12 Dec 2024 18:57:56 +0100 Subject: [PATCH 1/6] wip --- narwhals/_polars/expr.py | 32 +++++++++++++++++++------------- narwhals/_polars/series.py | 35 ++++++++++++++++++++++------------- 2 files changed, 41 insertions(+), 26 deletions(-) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index e5a0a9749a..ed9482d3dc 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -8,6 +8,7 @@ from narwhals._polars.utils import extract_args_kwargs from narwhals._polars.utils import extract_native from narwhals._polars.utils import narwhals_to_native_dtype +from narwhals.dependencies import get_polars from narwhals.utils import Implementation if TYPE_CHECKING: @@ -60,21 +61,26 @@ def ewm_mean( min_periods: int, ignore_nulls: bool, ) -> Self: - if self._backend_version < (1,): # pragma: no cover - msg = "`ewm_mean` not implemented for polars older than 1.0" - raise NotImplementedError(msg) expr = self._native_expr - return self._from_native_expr( - expr.ewm_mean( - com=com, - span=span, - half_life=half_life, - alpha=alpha, - adjust=adjust, - min_periods=min_periods, - ignore_nulls=ignore_nulls, - ) + + native_expr = expr.ewm_mean( + com=com, + span=span, + half_life=half_life, + alpha=alpha, + adjust=adjust, + min_periods=min_periods, + ignore_nulls=ignore_nulls, ) + if self._backend_version < (1,): # pragma: no cover + pl = get_polars() + result = self._from_native_expr( + pl.when(expr.is_null()).then(None).otherwise(native_expr).name.keep() + ) + + else: + result = self._from_native_expr(native_expr) + return result def map_batches( self, diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 635ef6d134..a062b4679a 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -9,6 +9,7 @@ from narwhals._polars.utils import extract_native from narwhals._polars.utils import narwhals_to_native_dtype from narwhals._polars.utils import native_to_narwhals_dtype +from narwhals.dependencies import get_polars from narwhals.utils import Implementation if TYPE_CHECKING: @@ -262,21 +263,29 @@ def ewm_mean( min_periods: int, ignore_nulls: bool, ) -> Self: + native_series = self._native_series + + native_result = native_series.ewm_mean( + com=com, + span=span, + half_life=half_life, + alpha=alpha, + adjust=adjust, + min_periods=min_periods, + ignore_nulls=ignore_nulls, + ) if self._backend_version < (1,): # pragma: no cover - msg = "`ewm_mean` not implemented for polars older than 1.0" - raise NotImplementedError(msg) - expr = self._native_series - return self._from_native_series( - expr.ewm_mean( - com=com, - span=span, - half_life=half_life, - alpha=alpha, - adjust=adjust, - min_periods=min_periods, - ignore_nulls=ignore_nulls, + pl = get_polars() + result = self._from_native_series( + pl.select( + pl.when(~native_series.is_null()).then(native_result).otherwise(None) + )[native_series.name] ) - ) + + else: + result = self._from_native_series(native_result) + + return result def sort(self: Self, *, descending: bool, nulls_last: bool) -> Self: if self._backend_version < (0, 20, 6): From 258a81f786a34891bceb1e31741f4d1519e4d2d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dea=20Mar=C3=ADa=20L=C3=A9on?= Date: Thu, 12 Dec 2024 19:03:01 +0100 Subject: [PATCH 2/6] works on expr and series --- tests/expr_and_series/ewm_test.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/tests/expr_and_series/ewm_test.py b/tests/expr_and_series/ewm_test.py index 5277576ce6..f2fd8727e5 100644 --- a/tests/expr_and_series/ewm_test.py +++ b/tests/expr_and_series/ewm_test.py @@ -4,7 +4,6 @@ import pytest import narwhals.stable.v1 as nw -from tests.utils import POLARS_VERSION from tests.utils import Constructor from tests.utils import ConstructorEager from tests.utils import assert_equal_data @@ -16,9 +15,7 @@ "ignore:`Expr.ewm_mean` is being called from the stable API although considered an unstable feature." ) def test_ewm_mean_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None: - if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or ( - "polars" in str(constructor) and POLARS_VERSION < (1,) - ): + if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) @@ -36,9 +33,7 @@ def test_ewm_mean_expr(request: pytest.FixtureRequest, constructor: Constructor) def test_ewm_mean_series( request: pytest.FixtureRequest, constructor_eager: ConstructorEager ) -> None: - if any(x in str(constructor_eager) for x in ("pyarrow_table_", "modin")) or ( - "polars" in str(constructor_eager) and POLARS_VERSION < (1,) - ): + if any(x in str(constructor_eager) for x in ("pyarrow_table_", "modin")): request.applymarker(pytest.mark.xfail) series = nw.from_native(constructor_eager(data), eager_only=True)["a"] @@ -75,9 +70,7 @@ def test_ewm_mean_expr_adjust( adjust: bool, # noqa: FBT001 expected: dict[str, list[float]], ) -> None: - if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")) or ( - "polars" in str(constructor) and POLARS_VERSION < (1,) - ): + if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin")): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) @@ -137,9 +130,7 @@ def test_ewm_mean_nulls( expected: dict[str, list[float]], constructor: Constructor, ) -> None: - if any( - x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf") - ) or ("polars" in str(constructor) and POLARS_VERSION < (1,)): + if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf")): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [2.0, 4.0, None, 3.0]})) @@ -154,9 +145,7 @@ def test_ewm_mean_params( request: pytest.FixtureRequest, constructor: Constructor, ) -> None: - if any( - x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf") - ) or ("polars" in str(constructor) and POLARS_VERSION < (1,)): + if any(x in str(constructor) for x in ("pyarrow_table_", "dask", "modin", "cudf")): request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor({"a": [2, 5, 3]})) From 5680c1c914a33fe78f60b7636afa9029c7427838 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dea=20Mar=C3=ADa=20L=C3=A9on?= Date: Thu, 12 Dec 2024 19:05:00 +0100 Subject: [PATCH 3/6] shorter version of expr --- narwhals/_polars/expr.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index ed9482d3dc..8dd461098a 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -74,13 +74,10 @@ def ewm_mean( ) if self._backend_version < (1,): # pragma: no cover pl = get_polars() - result = self._from_native_expr( + return self._from_native_expr( pl.when(expr.is_null()).then(None).otherwise(native_expr).name.keep() ) - - else: - result = self._from_native_expr(native_expr) - return result + return self._from_native_expr(native_expr) def map_batches( self, From 8b7eb417715582c8404593eed15f304371f42586 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dea=20Mar=C3=ADa=20L=C3=A9on?= Date: Thu, 12 Dec 2024 19:05:48 +0100 Subject: [PATCH 4/6] shorter version of series --- narwhals/_polars/series.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index a062b4679a..32170ccea4 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -276,16 +276,13 @@ def ewm_mean( ) if self._backend_version < (1,): # pragma: no cover pl = get_polars() - result = self._from_native_series( + return self._from_native_series( pl.select( pl.when(~native_series.is_null()).then(native_result).otherwise(None) )[native_series.name] ) - else: - result = self._from_native_series(native_result) - - return result + return self._from_native_series(native_result) def sort(self: Self, *, descending: bool, nulls_last: bool) -> Self: if self._backend_version < (0, 20, 6): From d550aa6f2a1d2cc3400ff0d5c49e1cbff721c174 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dea=20Mar=C3=ADa=20L=C3=A9on?= Date: Fri, 13 Dec 2024 09:32:36 +0100 Subject: [PATCH 5/6] mirror series in expr, use import polars --- narwhals/_polars/expr.py | 6 +++--- narwhals/_polars/series.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 8dd461098a..879bf479cb 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -8,7 +8,6 @@ from narwhals._polars.utils import extract_args_kwargs from narwhals._polars.utils import extract_native from narwhals._polars.utils import narwhals_to_native_dtype -from narwhals.dependencies import get_polars from narwhals.utils import Implementation if TYPE_CHECKING: @@ -73,9 +72,10 @@ def ewm_mean( ignore_nulls=ignore_nulls, ) if self._backend_version < (1,): # pragma: no cover - pl = get_polars() + import polars as pl + return self._from_native_expr( - pl.when(expr.is_null()).then(None).otherwise(native_expr).name.keep() + pl.when(~expr.is_null()).then(native_expr).otherwise(None).name.keep() ) return self._from_native_expr(native_expr) diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index 32170ccea4..424c5bfbe7 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -9,7 +9,6 @@ from narwhals._polars.utils import extract_native from narwhals._polars.utils import narwhals_to_native_dtype from narwhals._polars.utils import native_to_narwhals_dtype -from narwhals.dependencies import get_polars from narwhals.utils import Implementation if TYPE_CHECKING: @@ -275,7 +274,8 @@ def ewm_mean( ignore_nulls=ignore_nulls, ) if self._backend_version < (1,): # pragma: no cover - pl = get_polars() + import polars as pl + return self._from_native_series( pl.select( pl.when(~native_series.is_null()).then(native_result).otherwise(None) From 3f310939bf7ed552e74c8ac4fcafb0b867ff1999 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Fri, 13 Dec 2024 09:37:59 +0000 Subject: [PATCH 6/6] Update narwhals/_polars/expr.py --- narwhals/_polars/expr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/narwhals/_polars/expr.py b/narwhals/_polars/expr.py index 879bf479cb..4c6e59a4bf 100644 --- a/narwhals/_polars/expr.py +++ b/narwhals/_polars/expr.py @@ -75,7 +75,7 @@ def ewm_mean( import polars as pl return self._from_native_expr( - pl.when(~expr.is_null()).then(native_expr).otherwise(None).name.keep() + pl.when(~expr.is_null()).then(native_expr).otherwise(None) ) return self._from_native_expr(native_expr)