Skip to content

Commit

Permalink
Still the test case is failing
Browse files Browse the repository at this point in the history
  • Loading branch information
MUKESHRAJMAHENDRAN committed Dec 4, 2024
1 parent 0e5cf1a commit acfc6e9
Show file tree
Hide file tree
Showing 9 changed files with 126 additions and 58 deletions.
4 changes: 2 additions & 2 deletions docs/api-reference/expr.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
- alias
- all
- any
- argmin
- argmax
- arg_min
- arg_max
- arg_true
- cast
- count
Expand Down
4 changes: 2 additions & 2 deletions docs/api-reference/narwhals.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ Here are the top-level functions available in Narwhals.
- all
- all_horizontal
- any_horizontal
- argmax
- argmin
- arg_max
- arg_min
- col
- concat
- concat_str
Expand Down
4 changes: 2 additions & 2 deletions docs/api-reference/series.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
- alias
- all
- any
- argmin
- argmax
- arg_min
- arg_max
- arg_true
- cast
- clip
Expand Down
8 changes: 4 additions & 4 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
from narwhals.expr import all_ as all
from narwhals.expr import all_horizontal
from narwhals.expr import any_horizontal
from narwhals.expr import argmax
from narwhals.expr import argmin
from narwhals.expr import arg_max
from narwhals.expr import arg_min
from narwhals.expr import col
from narwhals.expr import concat_str
from narwhals.expr import len_ as len
Expand Down Expand Up @@ -107,8 +107,8 @@
"all",
"all_horizontal",
"any_horizontal",
"argmax",
"argmin",
"arg_max",
"arg_min",
"col",
"concat",
"concat_str",
Expand Down
16 changes: 8 additions & 8 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def median(self) -> Self:
"""
return self.__class__(lambda plx: self._call(plx).median())

def argmin(self) -> Self:
def arg_min(self) -> Self:
"""Get the index of the minimum value.
Returns:
Expand Down Expand Up @@ -665,9 +665,9 @@ def argmin(self) -> Self:
a: [[0]]
b: [[0]]
"""
return self.__class__(lambda plx: self._call(plx).argmin())
return self.__class__(lambda plx: self._call(plx).arg_min())

def argmax(self) -> Self:
def arg_max(self) -> Self:
"""Get the index of the maximum value.
Returns:
Expand Down Expand Up @@ -711,7 +711,7 @@ def argmax(self) -> Self:
a: [[2]]
b: [[2]]
"""
return self.__class__(lambda plx: self._call(plx).argmax())
return self.__class__(lambda plx: self._call(plx).arg_max())

def std(self, *, ddof: int = 1) -> Self:
"""Get standard deviation.
Expand Down Expand Up @@ -5957,7 +5957,7 @@ def median(*columns: str) -> Expr:
return Expr(lambda plx: plx.median(*columns))


def argmin(*columns: str) -> Expr:
def arg_min(*columns: str) -> Expr:
"""Return the index of the minimum value.
Note:
Expand Down Expand Up @@ -6005,7 +6005,7 @@ def argmin(*columns: str) -> Expr:
----
b: [[0]]
"""
return Expr(lambda plx: plx.argmin(*columns))
return Expr(lambda plx: plx.arg_min(*columns))


def min(*columns: str) -> Expr:
Expand Down Expand Up @@ -6059,7 +6059,7 @@ def min(*columns: str) -> Expr:
return Expr(lambda plx: plx.min(*columns))


def argmax(*columns: str) -> Expr:
def arg_max(*columns: str) -> Expr:
"""Return the index of the maximum value.
Note:
Expand Down Expand Up @@ -6107,7 +6107,7 @@ def argmax(*columns: str) -> Expr:
----
a: [[1]]
"""
return Expr(lambda plx: plx.argmax(*columns))
return Expr(lambda plx: plx.arg_max(*columns))


def max(*columns: str) -> Expr:
Expand Down
8 changes: 4 additions & 4 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ def median(self) -> Any:
"""
return self._compliant_series.median()

def argmin(self) -> Any:
def arg_min(self) -> Any:
"""Get the index of the minimum value in this Series.
Examples:
Expand All @@ -685,9 +685,9 @@ def argmin(self) -> Any:
>>> my_library_agnostic_function(s_pl)
0
"""
return self._compliant_series.argmin()
return self._from_compliant_series(self._compliant_series.arg_min())

def argmax(self) -> Any:
def arg_max(self) -> Any:
"""Get the index of the maximum value in this Series.
Examples:
Expand All @@ -712,7 +712,7 @@ def argmax(self) -> Any:
>>> my_library_agnostic_function(s_pl)
2
"""
return self._compliant_series.argmax()
return self._from_compliant_series(self._compliant_series.arg_max())

def skew(self: Self) -> Any:
"""Calculate the sample skewness of the Series.
Expand Down
8 changes: 4 additions & 4 deletions narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2144,7 +2144,7 @@ def min(*columns: str) -> Expr:
return _stableify(nw.min(*columns))


def argmin(*columns: str) -> Expr:
def arg_min(*columns: str) -> Expr:
"""Return the index of the minimum value.
Note:
Expand Down Expand Up @@ -2192,7 +2192,7 @@ def argmin(*columns: str) -> Expr:
----
b: [[0]]
"""
return _stableify(nw.argmin(*columns))
return _stableify(nw.arg_min(*columns))


def max(*columns: str) -> Expr:
Expand Down Expand Up @@ -2246,7 +2246,7 @@ def max(*columns: str) -> Expr:
return _stableify(nw.max(*columns))


def argmax(*columns: str) -> Expr:
def arg_max(*columns: str) -> Expr:
"""Return the index of the maximum value.
Note:
Expand Down Expand Up @@ -2294,7 +2294,7 @@ def argmax(*columns: str) -> Expr:
----
a: [[1]]
"""
return _stableify(nw.argmax(*columns))
return _stableify(nw.arg_max(*columns))


def mean(*columns: str) -> Expr:
Expand Down
66 changes: 50 additions & 16 deletions tests/expr_and_series/arg_max_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,63 @@

import pytest

import narwhals.stable.v1 as nw
import narwhals.stable.v1 as nw # Assuming this is a placeholder for your abstraction layer
from tests.utils import PANDAS_VERSION
from tests.utils import PYARROW_VERSION
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data

data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
# Define sample data for testing
data = {"a": [1, 3, None, 2]}

# Expected results for arg_max
expected = {
"arg_max": [1, 1, None, 1],
}

@pytest.mark.parametrize(
"expr", [nw.col("a", "b", "z").argmax(), nw.argmax("a", "b", "z")]
)
def test_expr_argmax_expr(constructor: Constructor, expr: nw.Expr) -> None:

def test_arg_max_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None:
# Handle version-specific expected failures
if PYARROW_VERSION < (13, 0, 0) and "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

if (PANDAS_VERSION < (2, 1) or PYARROW_VERSION < (13,)) and "pandas_pyarrow" in str(
constructor
):
request.applymarker(pytest.mark.xfail)

# Create a DataFrame from the constructor
df = nw.from_native(constructor(data))
result = df.select(expr)
# The index of maximum values: 'a' -> 1 (value 3), 'b' -> 2 (value 6), 'z' -> 2 (value 9)
expected = {"a": [1], "b": [2], "z": [2]}
assert_equal_data(result, expected)

# Test the arg_max expression
result = df.select(
nw.col("a").arg_max().alias("arg_max"),
)

# Assert that the result matches the expected data
assert_equal_data(result, {"arg_max": expected["arg_max"]})


@pytest.mark.parametrize(("col", "expected"), [("a", 1), ("b", 2), ("z", 2)])
def test_expr_argmax_series(
constructor_eager: ConstructorEager, col: str, expected: int
def test_arg_max_series(
request: pytest.FixtureRequest, constructor_eager: ConstructorEager
) -> None:
series = nw.from_native(constructor_eager(data), eager_only=True)[col]
result = series.argmax()
assert_equal_data({col: [result]}, {col: [expected]})
# Version-specific xfail setup
if PYARROW_VERSION < (13, 0, 0) and "pyarrow_table" in str(constructor_eager):
request.applymarker(pytest.mark.xfail)

if (PANDAS_VERSION < (2, 1) or PYARROW_VERSION < (13,)) and "pandas_pyarrow" in str(
constructor_eager
):
request.applymarker(pytest.mark.xfail)

# Create a DataFrame for eager computation
df = nw.from_native(constructor_eager(data), eager_only=True)

# Test arg_max on series level
result = df.select(
arg_max=df["a"].arg_max(),
)

# Assert that the data matches the expected output
assert_equal_data(result, {"arg_max": expected["arg_max"]})
66 changes: 50 additions & 16 deletions tests/expr_and_series/arg_min_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,63 @@

import pytest

import narwhals.stable.v1 as nw
import narwhals.stable.v1 as nw # Assuming this is a placeholder for your abstraction layer
from tests.utils import PANDAS_VERSION
from tests.utils import PYARROW_VERSION
from tests.utils import Constructor
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data

data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]}
# Define sample data for testing
data = {"a": [1, 3, None, 2]}

# Expected results for arg_min
expected = {
"arg_min": [0, 0, None, 0],
}

@pytest.mark.parametrize(
"expr", [nw.col("a", "b", "z").argmin(), nw.argmin("a", "b", "z")]
)
def test_expr_argmin_expr(constructor: Constructor, expr: nw.Expr) -> None:

def test_arg_min_expr(request: pytest.FixtureRequest, constructor: Constructor) -> None:
# Handle version-specific expected failures
if PYARROW_VERSION < (13, 0, 0) and "pyarrow_table" in str(constructor):
request.applymarker(pytest.mark.xfail)

if (PANDAS_VERSION < (2, 1) or PYARROW_VERSION < (13,)) and "pandas_pyarrow" in str(
constructor
):
request.applymarker(pytest.mark.xfail)

# Create a DataFrame from the constructor
df = nw.from_native(constructor(data))
result = df.select(expr)
# The index of minimum values: 'a' -> 0 (value 1), 'b' -> 0 (value 4), 'z' -> 0 (value 7.0)
expected = {"a": [0], "b": [0], "z": [0]}
assert_equal_data(result, expected)

# Test the arg_min expression
result = df.select(
nw.col("a").arg_min().alias("arg_min"),
)

# Assert that the result matches the expected data
assert_equal_data(result, {"arg_min": expected["arg_min"]})


@pytest.mark.parametrize(("col", "expected"), [("a", 0), ("b", 0), ("z", 0)])
def test_expr_argmin_series(
constructor_eager: ConstructorEager, col: str, expected: int
def test_arg_min_series(
request: pytest.FixtureRequest, constructor_eager: ConstructorEager
) -> None:
series = nw.from_native(constructor_eager(data), eager_only=True)[col]
result = series.argmin()
assert_equal_data({col: [result]}, {col: [expected]})
# Version-specific xfail setup
if PYARROW_VERSION < (13, 0, 0) and "pyarrow_table" in str(constructor_eager):
request.applymarker(pytest.mark.xfail)

if (PANDAS_VERSION < (2, 1) or PYARROW_VERSION < (13,)) and "pandas_pyarrow" in str(
constructor_eager
):
request.applymarker(pytest.mark.xfail)

# Create a DataFrame for eager computation
df = nw.from_native(constructor_eager(data), eager_only=True)

# Test arg_min on series level
result = df.select(
arg_min=df["a"].arg_min(),
)

# Assert that the data matches the expected output
assert_equal_data(result, {"arg_min": expected["arg_min"]})

0 comments on commit acfc6e9

Please sign in to comment.