Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Series.__contains__ #1480

Merged
merged 10 commits into from
Dec 3, 2024
10 changes: 10 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,16 @@ def __iter__(self: Self) -> Iterator[Any]:
for x in self._native_series.__iter__()
)

def __contains__(self: Self, other: Any) -> bool:
import pyarrow as pa # ignore-banned-imports
import pyarrow.compute as pc # ignore-banned-imports

native_series = self._native_series
return maybe_extract_py_scalar( # type: ignore[no-any-return]
pc.is_in(pa.scalar(other, type=native_series.type), native_series),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if other is of a type not convertible to native_series.type, will this raise? I'm wondering if we should return False in such a case?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like polars would also raise here:

In [2]: 'foo' in pl.Series([1,2,3])
---------------------------------------------------------------------------
InvalidOperationError                     Traceback (most recent call last)
Cell In[2], line 1
----> 1 'foo' in pl.Series([1,2,3])

File ~/scratch/.venv/lib/python3.12/site-packages/polars/series/series.py:1233, in Series.__contains__(self, item)
   1231 if item is None:
   1232     return self.has_nulls()
-> 1233 return self.implode().list.contains(item).item()

File ~/scratch/.venv/lib/python3.12/site-packages/polars/series/utils.py:106, in call_expr.<locals>.wrapper(self, *args, **kwargs)
    104     expr = getattr(expr, namespace)
    105 f = getattr(expr, func.__name__)
--> 106 return s.to_frame().select_seq(f(*args, **kwargs)).to_series()

File ~/scratch/.venv/lib/python3.12/site-packages/polars/dataframe/frame.py:9138, in DataFrame.select_seq(self, *exprs, **named_exprs)
   9115 def select_seq(
   9116     self, *exprs: IntoExpr | Iterable[IntoExpr], **named_exprs: IntoExpr
   9117 ) -> DataFrame:
   9118     """
   9119     Select columns from this DataFrame.
   9120
   (...)
   9136     select
   9137     """
-> 9138     return self.lazy().select_seq(*exprs, **named_exprs).collect(_eager=True)

File ~/scratch/.venv/lib/python3.12/site-packages/polars/lazyframe/frame.py:2029, in LazyFrame.collect(self, type_coercion, predicate_pushdown, projection_pushdown, simplify_expression, slice_pushdown, comm_subplan_elim, comm_subexpr_elim, cluster_with_columns, collapse_joins, no_optimization, streaming, engine, background, _eager, **_kwargs)
   2027 # Only for testing purposes
   2028 callback = _kwargs.get("post_opt_callback", callback)
-> 2029 return wrap_df(ldf.collect(callback))

InvalidOperationError: is_in operation not supported for dtypes `str` and `list[i64]`

probably ok to just follow them here then

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still, probably good to have a test to check it raises? we can leave the expression unification for a separate topic

Copy link
Member Author

@FBruzzesi FBruzzesi Dec 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will add a test and make sure that same exception is raised πŸ‘Œ

Copy link
Member Author

@FBruzzesi FBruzzesi Dec 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok this is definitly wrong behaviour for arrow as the casting would convert from float to int, leading to the following:

import pyarrow as pa
import narwhals as nw

data = [100, 200, None]
s = nw.from_native(pa.table({"a": data}), eager_only=True)["a"]

100.3 in s
True

On the Exception side, I would not be sure how to catch the error for pandas as:

(pd.Series([1,2,3]) == "foo").any()
np.False_

In a way, I find polars behaviour more surprising

Copy link
Member

@MarcoGorelli MarcoGorelli Dec 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tbh we could just dispatch to (s._compliant_series == other).any() for all backends πŸ˜„ that would still be better than the status quo

return_py_scalar=True,
)

@property
def shape(self: Self) -> tuple[int]:
return (len(self._native_series),)
Expand Down
5 changes: 5 additions & 0 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,6 +851,11 @@ def rolling_mean(
def __iter__(self: Self) -> Iterator[Any]:
yield from self._native_series.__iter__()

def __contains__(self: Self, other: Any) -> bool:
if other is None:
return self._native_series.isna().any() # type: ignore[no-any-return]
FBruzzesi marked this conversation as resolved.
Show resolved Hide resolved
return self._native_series.isin({other}).any() # type: ignore[no-any-return]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this over return (self._native_series == other).any()?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No strong feelings πŸ˜‚ I guess I went for series.isin(...) as I initially fell for the other in series

Copy link
Member

@MarcoGorelli MarcoGorelli Dec 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"fell for" πŸ˜„ love that

from a quick test, it looks a lot faster

In [13]: s = pd.Series(rng.integers(0, 10, size=1_000_000))

In [14]: %timeit s.isin({0}).any()
6.58 ms Β± 687 ΞΌs per loop (mean Β± std. dev. of 7 runs, 100 loops each)

In [15]: %timeit (s == 0).any()
326 ΞΌs Β± 45.8 ΞΌs per loop (mean Β± std. dev. of 7 runs, 1,000 loops each)


def is_finite(self: Self) -> Self:
s = self._native_series
return self._from_native_series((s > float("-inf")) & (s < float("inf")))
Expand Down
3 changes: 3 additions & 0 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3324,6 +3324,9 @@ def rolling_mean(
def __iter__(self: Self) -> Iterator[Any]:
yield from self._compliant_series.__iter__()

def __contains__(self: Self, other: Any) -> bool:
return self._compliant_series.__contains__(other) # type: ignore[no-any-return]

@property
def str(self: Self) -> SeriesStringNamespace[Self]:
return SeriesStringNamespace(self)
Expand Down
23 changes: 23 additions & 0 deletions tests/series_only/__contains___test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

import narwhals.stable.v1 as nw

if TYPE_CHECKING:
from tests.utils import ConstructorEager

data = [100, 200, None]


@pytest.mark.parametrize(("other", "expected"), [(100, True), (None, True), (3, False)])
def test_contains(
constructor_eager: ConstructorEager,
other: int | None,
expected: bool, # noqa: FBT001
) -> None:
s = nw.from_native(constructor_eager({"a": data}), eager_only=True)["a"]

assert (other in s) == expected
1 change: 1 addition & 0 deletions utils/check_api_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"value_counts",
"zip_with",
"__iter__",
"__contains__",
}
BASE_DTYPES = {
"NumericType",
Expand Down
Loading