Skip to content

Commit

Permalink
Merge branch 'main' into 35-concrete-gridpolars
Browse files Browse the repository at this point in the history
  • Loading branch information
adamamer20 committed Aug 14, 2024
2 parents 54e25d9 + abb4459 commit 9b4c2df
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 8 deletions.
4 changes: 3 additions & 1 deletion mesa_frames/concrete/pandas/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,9 @@ def _df_reindex(
other: Sequence[Hashable] | pd.DataFrame,
index_cols: str | list[str],
) -> pd.DataFrame:
return df.reindex(other)
df = df.reindex(other)
df.index.name = index_cols
return df

def _df_rename_columns(
self,
Expand Down
11 changes: 4 additions & 7 deletions mesa_frames/concrete/polars/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ def _df_constructor(
)
if index is not None:
if index_cols is not None:
if isinstance(index_cols, str):
index_cols = [index_cols]
index_df = pl.DataFrame(index, index_cols)
else:
index_df = pl.DataFrame(index)
Expand Down Expand Up @@ -316,12 +318,7 @@ def _df_join(
left_on, right_on = right_on, left_on
how = "left"
return left.join(
right,
on=on,
left_on=left_on,
right_on=right_on,
how=how,
suffix=suffix,
right, on=on, left_on=left_on, right_on=right_on, how=how, suffix=suffix
)

def _df_lt(
Expand Down Expand Up @@ -470,7 +467,7 @@ def _df_reindex(
other = other.select(index_cols)
else:
# If other is a sequence, create a DataFrame with it
other = pl.DataFrame(index_cols=other)
other = pl.Series(name=index_cols, values=other).to_frame()

# Perform a left join to reindex
result = other.join(df, on=index_cols, how="left")
Expand Down
62 changes: 62 additions & 0 deletions tests/pandas/test_mixin_pandas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pandas as pd
import pytest

from mesa_frames.concrete.pandas.mixin import PandasMixin


@pytest.fixture
def df_or():
return PandasMixin()._df_or


@pytest.fixture
def df_0():
return pd.DataFrame(
{
"unique_id": ["x", "y", "z"],
"A": [1, 0, 1],
"B": ["a", "b", "c"],
"C": [True, False, True],
"D": [0, 1, 1],
}
).set_index("unique_id")


@pytest.fixture
def df_1():
return pd.DataFrame(
{
"unique_id": ["z", "a", "b"],
"A": [0, 1, 0],
"B": ["d", "e", "f"],
"C": [False, True, False],
"E": [1, 0, 1],
}
).set_index("unique_id")


def test_df_or(df_or: df_or, df_0: pd.DataFrame, df_1: pd.DataFrame):
# Test comparing the DataFrame with a sequence element-wise along the rows (axis='index')
df_0["F"] = [True, True, False]
df_1["F"] = [False, False, True]
result = df_or(df_0[["C", "F"]], df_1["F"], axis="index")
assert isinstance(result, pd.DataFrame)
assert result["C"].tolist() == [True, False, True]
assert result["F"].tolist() == [True, True, True]

# Test comparing the DataFrame with a sequence element-wise along the columns (axis='columns')
result = df_or(df_0[["C", "F"]], [True, False], axis="columns")
assert isinstance(result, pd.DataFrame)
assert result["C"].tolist() == [True, True, True]
assert result["F"].tolist() == [True, True, False]

# Test comparing DataFrames with index-column alignment
result = df_or(
df_0[["C", "F"]],
df_1[["C", "F"]],
axis="index",
index_cols="unique_id",
)
assert isinstance(result, pd.DataFrame)
assert result["C"].tolist() == [True, False, True]
assert result["F"].tolist() == [True, True, False]
31 changes: 31 additions & 0 deletions tests/polars/test_mixin_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,16 @@ def test_df_constructor(self, mixin: PolarsMixin):
assert df["num"].to_list() == [1, 2, 3]
assert df["letter"].to_list() == ["a", "b", "c"]

# Test with index > 1 and 1 value
data = {"a": 5}
df = mixin._df_constructor(
data, index=pl.int_range(5, eager=True), index_cols="index"
)
assert isinstance(df, pl.DataFrame)
assert list(df.columns) == ["index", "a"]
assert df["a"].to_list() == [5, 5, 5, 5, 5]
assert df["index"].to_list() == [0, 1, 2, 3, 4]

def test_df_contains(self, mixin: PolarsMixin, df_0: pl.DataFrame):
# Test with list
result = mixin._df_contains(df_0, "A", [5, 2, 3])
Expand Down Expand Up @@ -629,6 +639,27 @@ def test_df_or(self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame)
assert result["C"].to_list() == [True, None, True]
assert result["F"].to_list() == [True, True, False]

def test_df_reindex(
self, mixin: PolarsMixin, df_0: pl.DataFrame, df_1: pl.DataFrame
):
# Test with DataFrame
reindexed = mixin._df_reindex(df_0, df_1, "unique_id")
assert isinstance(reindexed, pl.DataFrame)
assert reindexed["unique_id"].to_list() == ["z", "a", "b"]
assert reindexed["A"].to_list() == [3, None, None]
assert reindexed["B"].to_list() == ["c", None, None]
assert reindexed["C"].to_list() == [True, None, None]
assert reindexed["D"].to_list() == [3, None, None]

# Test with list
reindexed = mixin._df_reindex(df_0, ["z", "a", "b"], "unique_id")
assert isinstance(reindexed, pl.DataFrame)
assert reindexed["unique_id"].to_list() == ["z", "a", "b"]
assert reindexed["A"].to_list() == [3, None, None]
assert reindexed["B"].to_list() == ["c", None, None]
assert reindexed["C"].to_list() == [True, None, None]
assert reindexed["D"].to_list() == [3, None, None]

def test_df_rename_columns(self, mixin: PolarsMixin, df_0: pl.DataFrame):
renamed = mixin._df_rename_columns(df_0, ["A", "B"], ["X", "Y"])
assert renamed.columns == ["unique_id", "X", "Y", "C", "D"]
Expand Down

0 comments on commit 9b4c2df

Please sign in to comment.