Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Series.__contains__ #1480

Merged
merged 10 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,22 @@
from narwhals.dtypes import UInt64
from narwhals.dtypes import Unknown
from narwhals.expr import Expr
from narwhals.expr import all_ as all # noqa: A004
from narwhals.expr import all_ as all
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok something must have gone wrong with the pre-commit sync :|

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

πŸ€” i'm so confused πŸ˜„

from narwhals.expr import all_horizontal
from narwhals.expr import any_horizontal
from narwhals.expr import col
from narwhals.expr import concat_str
from narwhals.expr import len_ as len # noqa: A004
from narwhals.expr import len_ as len
from narwhals.expr import lit
from narwhals.expr import max # noqa: A004
from narwhals.expr import max
from narwhals.expr import max_horizontal
from narwhals.expr import mean
from narwhals.expr import mean_horizontal
from narwhals.expr import median
from narwhals.expr import min # noqa: A004
from narwhals.expr import min
from narwhals.expr import min_horizontal
from narwhals.expr import nth
from narwhals.expr import sum # noqa: A004
from narwhals.expr import sum
from narwhals.expr import sum_horizontal
from narwhals.expr import when
from narwhals.functions import concat
Expand Down
19 changes: 19 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,25 @@ def __iter__(self: Self) -> Iterator[Any]:
for x in self._native_series.__iter__()
)

def __contains__(self: Self, other: Any) -> bool:
from pyarrow import ArrowNotImplementedError # ignore-banned-imports
from pyarrow import ArrowTypeError # ignore-banned-imports

try:
import pyarrow as pa # ignore-banned-imports
import pyarrow.compute as pc # ignore-banned-imports

native_series = self._native_series
return maybe_extract_py_scalar( # type: ignore[no-any-return]
pc.is_in(pa.scalar(other), native_series),
return_py_scalar=True,
)
except (ArrowNotImplementedError, ArrowTypeError) as exc:
from narwhals.exceptions import InvalidOperationError

msg = f"Unable to compare other of type {type(other)} with series of type {self.dtype}."
raise InvalidOperationError(msg) from exc

@property
def shape(self: Self) -> tuple[int]:
return (len(self._native_series),)
Expand Down
7 changes: 7 additions & 0 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,13 @@ def rolling_mean(
def __iter__(self: Self) -> Iterator[Any]:
yield from self._native_series.__iter__()

def __contains__(self: Self, other: Any) -> bool:
return ( # type: ignore[no-any-return]
self._native_series.isna().any()
if other is None
else (self._native_series == other).any()
)

def is_finite(self: Self) -> Self:
s = self._native_series
return self._from_native_series((s > float("-inf")) & (s < float("inf")))
Expand Down
11 changes: 11 additions & 0 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,17 @@ def cum_count(self: Self, *, reverse: bool) -> Self:

return self._from_native_series(result)

def __contains__(self: Self, other: Any) -> bool:
from polars.exceptions import InvalidOperationError as PlInvalidOperationError

try:
return self._native_series.__contains__(other)
except PlInvalidOperationError as exc:
from narwhals.exceptions import InvalidOperationError

msg = f"Unable to compare other of type {type(other)} with series of type {self.dtype}."
raise InvalidOperationError(msg) from exc

@property
def dt(self: Self) -> PolarsSeriesDateTimeNamespace:
return PolarsSeriesDateTimeNamespace(self)
Expand Down
3 changes: 3 additions & 0 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3324,6 +3324,9 @@ def rolling_mean(
def __iter__(self: Self) -> Iterator[Any]:
yield from self._compliant_series.__iter__()

def __contains__(self: Self, other: Any) -> bool:
return self._compliant_series.__contains__(other) # type: ignore[no-any-return]

@property
def str(self: Self) -> SeriesStringNamespace[Self]:
return SeriesStringNamespace(self)
Expand Down
2 changes: 1 addition & 1 deletion narwhals/stable/v1/selectors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from narwhals.selectors import all # noqa: A004
from narwhals.selectors import all
from narwhals.selectors import boolean
from narwhals.selectors import by_dtype
from narwhals.selectors import categorical
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ lint.select = [
]
lint.ignore = [
"A001",
"A004",
"ARG002",
"ANN401",
"C901",
Expand Down
2 changes: 1 addition & 1 deletion tests/selectors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

import narwhals.stable.v1 as nw
from narwhals.stable.v1.selectors import all # noqa: A004
from narwhals.stable.v1.selectors import all
from narwhals.stable.v1.selectors import boolean
from narwhals.stable.v1.selectors import by_dtype
from narwhals.stable.v1.selectors import categorical
Expand Down
44 changes: 44 additions & 0 deletions tests/series_only/__contains___test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any

import pytest

import narwhals.stable.v1 as nw
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
from tests.utils import ConstructorEager

data = [100, 200, None]


@pytest.mark.parametrize(
("other", "expected"), [(100, True), (None, True), (1, False), (100.314, False)]
)
def test_contains(
constructor_eager: ConstructorEager,
other: int | None,
expected: bool, # noqa: FBT001
) -> None:
s = nw.from_native(constructor_eager({"a": data}), eager_only=True)["a"]

assert (other in s) == expected


@pytest.mark.parametrize("other", ["foo", [1, 2, 3]])
def test_contains_invalid_type(
request: pytest.FixtureRequest,
constructor_eager: ConstructorEager,
other: Any,
) -> None:
if "polars" not in str(constructor_eager) and "pyarrow_table" not in str(
constructor_eager
):
request.applymarker(pytest.mark.xfail)

s = nw.from_native(constructor_eager({"a": data}), eager_only=True)["a"]

with pytest.raises(InvalidOperationError):
_ = other in s
1 change: 1 addition & 0 deletions utils/check_api_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"value_counts",
"zip_with",
"__iter__",
"__contains__",
}
BASE_DTYPES = {
"NumericType",
Expand Down
Loading