diff --git a/mesa_frames/abstract/mixin.py b/mesa_frames/abstract/mixin.py index 910cae3..0f3599f 100644 --- a/mesa_frames/abstract/mixin.py +++ b/mesa_frames/abstract/mixin.py @@ -1,11 +1,13 @@ from abc import ABC, abstractmethod +from collections.abc import Collection, Iterator, Sequence from copy import copy, deepcopy - -from typing_extensions import Any, Self from typing import Literal -from collections.abc import Collection, Iterator, Sequence -from mesa_frames.types_ import BoolSeries, DataFrame, Mask, Series +from typing_extensions import Any, Self, overload + +from collections.abc import Hashable + +from mesa_frames.types_ import BoolSeries, DataFrame, Index, Mask, Series class CopyMixin(ABC): @@ -149,38 +151,119 @@ def __deepcopy__(self, memo: dict) -> Self: class DataFrameMixin(ABC): + def _df_remove(self, df: DataFrame, mask: Mask, index_cols: str) -> DataFrame: + return self._df_get_masked_df(df, index_cols, mask, negate=True) + @abstractmethod - def _df_add_columns( - self, original_df: DataFrame, new_columns: list[str], data: Any + def _df_add( + self, + df: DataFrame, + other: DataFrame | Sequence[float | int], + axis: Literal["index", "columns"] = "index", + index_cols: str | list[str] | None = None, + ) -> DataFrame: ... + + @abstractmethod + def _df_all( + self, + df: DataFrame, + name: str, + axis: str = "columns", + index_cols: str | list[str] | None = None, ) -> DataFrame: ... + @abstractmethod + def _df_column_names(self, df: DataFrame) -> list[str]: ... + @abstractmethod def _df_combine_first( - self, original_df: DataFrame, new_df: DataFrame, index_cols: list[str] + self, original_df: DataFrame, new_df: DataFrame, index_cols: str | list[str] ) -> DataFrame: ... + @overload @abstractmethod def _df_concat( self, - dfs: Collection[DataFrame], + objs: Collection[Series], how: Literal["horizontal"] | Literal["vertical"] = "vertical", ignore_index: bool = False, + index_cols: str | None = None, + ) -> Series: ... + + @overload + @abstractmethod + def _df_concat( + self, + objs: Collection[DataFrame], + how: Literal["horizontal"] | Literal["vertical"] = "vertical", + ignore_index: bool = False, + index_cols: str | None = None, ) -> DataFrame: ... + @abstractmethod + def _df_concat( + self, + objs: Collection[DataFrame] | Collection[Series], + how: Literal["horizontal"] | Literal["vertical"] = "vertical", + ignore_index: bool = False, + index_cols: str | None = None, + ) -> DataFrame | Series: ... + + @abstractmethod + def _df_contains( + self, + df: DataFrame, + column: str, + values: Sequence[Any], + ) -> BoolSeries: ... + @abstractmethod def _df_constructor( self, data: Sequence[Sequence] | dict[str | Any] | None = None, columns: list[str] | None = None, - index_col: str | list[str] | None = None, + index: Index | None = None, + index_cols: str | list[str] | None = None, dtypes: dict[str, Any] | None = None, ) -> DataFrame: ... + @abstractmethod + def _df_div( + self, + df: DataFrame, + other: DataFrame | Sequence[float | int], + axis: Literal["index", "columns"] = "index", + index_cols: str | list[str] | None = None, + ) -> DataFrame: ... + + @abstractmethod + def _df_drop_columns( + self, + df: DataFrame, + columns: str | list[str], + ) -> DataFrame: ... + + @abstractmethod + def _df_drop_duplicates( + self, + df: DataFrame, + subset: str | list[str] | None = None, + keep: Literal["first", "last", False] = "first", + ) -> DataFrame: ... + + @abstractmethod + def _df_filter( + self, + df: DataFrame, + condition: BoolSeries, + all: bool = True, + ) -> DataFrame: ... + @abstractmethod def _df_get_bool_mask( self, df: DataFrame, - index_col: str, + index_cols: str | list[str], mask: Mask | None = None, negate: bool = False, ) -> BoolSeries: ... @@ -189,21 +272,88 @@ def _df_get_bool_mask( def _df_get_masked_df( self, df: DataFrame, - index_col: str, + index_cols: str, mask: Mask | None = None, - columns: list[str] | None = None, + columns: str | list[str] | None = None, negate: bool = False, ) -> DataFrame: ... + @abstractmethod + def _df_groupby_cumcount( + self, + df: DataFrame, + by: str | list[str], + ) -> Series: ... + @abstractmethod def _df_iterator(self, df: DataFrame) -> Iterator[dict[str, Any]]: ... @abstractmethod - def _df_norm(self, df: DataFrame) -> DataFrame: ... + def _df_join( + self, + left: DataFrame, + right: DataFrame, + index_cols: str | list[str] | None = None, + on: str | list[str] | None = None, + left_on: str | list[str] | None = None, + right_on: str | list[str] | None = None, + how: Literal["left"] + | Literal["right"] + | Literal["inner"] + | Literal["outer"] + | Literal["cross"] = "left", + suffix="_right", + ) -> DataFrame: ... + + @abstractmethod + def _df_mul( + self, + df: DataFrame, + other: DataFrame | Sequence[float | int], + axis: Literal["index", "columns"] = "index", + index_cols: str | list[str] | None = None, + ) -> DataFrame: ... + + @abstractmethod + @overload + def _df_norm( + self, + df: DataFrame, + srs_name: str = "norm", + include_cols: Literal[False] = False, + ) -> Series: ... + + @abstractmethod + @overload + def _df_norm( + self, + df: DataFrame, + srs_name: str = "norm", + include_cols: Literal[True] = False, + ) -> DataFrame: ... + + @abstractmethod + def _df_norm( + self, + df: DataFrame, + srs_name: str = "norm", + include_cols: bool = False, + ) -> Series | DataFrame: ... @abstractmethod - def _df_remove( - self, df: DataFrame, ids: Sequence[Any], index_col: str | None = None + def _df_rename_columns( + self, + df: DataFrame, + old_columns: list[str], + new_columns: list[str], + ) -> DataFrame: ... + + @abstractmethod + def _df_reset_index( + self, + df: DataFrame, + index_cols: str | list[str] | None = None, + drop: bool = False, ) -> DataFrame: ... @abstractmethod @@ -217,6 +367,27 @@ def _df_sample( seed: int | None = None, ) -> DataFrame: ... + @abstractmethod + def _df_set_index( + self, + df: DataFrame, + index_name: str, + new_index: Sequence[Hashable] | None = None, + ) -> DataFrame: ... + + @abstractmethod + def _df_with_columns( + self, + original_df: DataFrame, + data: DataFrame + | Series + | Sequence[Sequence] + | dict[str | Any] + | Sequence[Any] + | Any, + new_columns: str | list[str] | None = None, + ) -> DataFrame: ... + @abstractmethod def _srs_constructor( self, @@ -225,3 +396,16 @@ def _srs_constructor( dtype: Any | None = None, index: Sequence[Any] | None = None, ) -> Series: ... + + @abstractmethod + def _srs_contains( + self, + srs: Sequence[Any], + values: Any | Sequence[Any], + ) -> BoolSeries: ... + + @abstractmethod + def _srs_range(self, name: str, start: int, end: int, step: int = 1) -> Series: ... + + @abstractmethod + def _srs_to_df(self, srs: Series, index: Index | None = None) -> DataFrame: ... diff --git a/mesa_frames/concrete/pandas/mixin.py b/mesa_frames/concrete/pandas/mixin.py index be22393..a5f99c6 100644 --- a/mesa_frames/concrete/pandas/mixin.py +++ b/mesa_frames/concrete/pandas/mixin.py @@ -1,111 +1,307 @@ from collections.abc import Collection, Iterator, Sequence from typing import Literal +from collections.abc import Hashable + import numpy as np import pandas as pd -from typing_extensions import Any +from typing_extensions import Any, overload from mesa_frames.abstract.mixin import DataFrameMixin from mesa_frames.types_ import PandasMask class PandasMixin(DataFrameMixin): - def _df_add_columns( - self, original_df: pd.DataFrame, new_columns: list[str], data: Any + def _df_add( + self, + df: pd.DataFrame, + other: pd.DataFrame | Sequence[float | int], + axis: Literal["index", "columns"] = "index", + index_cols: str | list[str] | None = None, + ) -> pd.DataFrame: + return df.add(other=other, axis=axis) + + def _df_all( + self, + df: pd.DataFrame, + name: str, + axis: str = "columns", + index_cols: str | list[str] | None = None, ) -> pd.DataFrame: - original_df[new_columns] = data - return original_df + return df.all(axis).to_frame(name) + + def _df_column_names(self, df: pd.DataFrame) -> list[str]: + return df.columns.tolist() + df.index.names def _df_combine_first( - self, original_df: pd.DataFrame, new_df: pd.DataFrame, index_cols: list[str] + self, + original_df: pd.DataFrame, + new_df: pd.DataFrame, + index_cols: str | list[str], ) -> pd.DataFrame: + if (isinstance(index_cols, str) and index_cols != original_df.index.name) or ( + isinstance(index_cols, list) and index_cols != original_df.index.names + ): + original_df = original_df.set_index(index_cols) + + if (isinstance(index_cols, str) and index_cols != original_df.index.name) or ( + isinstance(index_cols, list) and index_cols != original_df.index.names + ): + new_df = new_df.set_index(index_cols) return original_df.combine_first(new_df) + @overload def _df_concat( self, - dfs: Collection[pd.DataFrame], + objs: Collection[pd.DataFrame], how: Literal["horizontal"] | Literal["vertical"] = "vertical", ignore_index: bool = False, - ) -> pd.DataFrame: - return pd.concat( - dfs, axis=0 if how == "vertical" else 1, ignore_index=ignore_index + index_cols: str | None = None, + ) -> pd.DataFrame: ... + + @overload + def _df_concat( + self, + objs: Collection[pd.Series], + how: Literal["horizontal"] | Literal["vertical"] = "vertical", + ignore_index: bool = False, + index_cols: str | None = None, + ) -> pd.Series: ... + + def _df_concat( + self, + objs: Collection[pd.DataFrame] | Collection[pd.Series], + how: Literal["horizontal"] | Literal["vertical"] = "vertical", + ignore_index: bool = False, + index_cols: str | None = None, + ) -> pd.Series | pd.DataFrame: + df = pd.concat( + objs, axis=0 if how == "vertical" else 1, ignore_index=ignore_index ) + if index_cols: + return df.set_index(index_cols) + return df def _df_constructor( self, data: Sequence[Sequence] | dict[str | Any] | None = None, columns: list[str] | None = None, - index_col: str | list[str] | None = None, + index: Sequence[Hashable] | None = None, + index_cols: str | list[str] | None = None, dtypes: dict[str, Any] | None = None, ) -> pd.DataFrame: - df = pd.DataFrame(data=data, columns=columns).astype(dtypes) - if index_col: - df.set_index(index_col) + df = pd.DataFrame(data=data, columns=columns, index=index) + if dtypes: + df = df.astype(dtypes) + if index_cols: + df = df.set_index(index_cols) return df + def _df_contains( + self, + df: pd.DataFrame, + column: str, + values: Sequence[Any], + ) -> pd.Series: + if df.index.name == column: + return pd.Series(values).isin(df.index) + return pd.Series(values).isin(df[column]) + + def _df_filter( + self, + df: pd.DataFrame, + condition: pd.DataFrame, + all: bool = True, + ) -> pd.DataFrame: + if all and isinstance(condition, pd.DataFrame): + return df[condition.all(axis=1)] + return df[condition] + + def _df_div( + self, + df: pd.DataFrame, + other: pd.DataFrame | Sequence[float | int], + axis: Literal["index", "columns"] = "index", + index_cols: str | list[str] | None = None, + ) -> pd.DataFrame: + return df.div(other=other, axis=axis) + + def _df_drop_columns( + self, + df: pd.DataFrame, + columns: str | list[str], + ) -> pd.DataFrame: + return df.drop(columns=columns) + + def _df_drop_duplicates( + self, + df: pd.DataFrame, + subset: str | list[str] | None = None, + keep: Literal["first", "last", False] = "first", + ) -> pd.DataFrame: + return df.drop_duplicates(subset=subset, keep=keep) + def _df_get_bool_mask( self, df: pd.DataFrame, - index_col: str, + index_cols: str | list[str], mask: PandasMask = None, negate: bool = False, ) -> pd.Series: + # Get the index column + if (isinstance(index_cols, str) and df.index.name == index_cols) or ( + isinstance(index_cols, list) and df.index.names == index_cols + ): + srs = df.index + else: + srs = df.set_index(index_cols).index if isinstance(mask, pd.Series) and mask.dtype == bool and len(mask) == len(df): + mask.index = df.index result = mask - elif isinstance(mask, pd.DataFrame): - if mask.index.name == index_col: - result = pd.Series(df.index.isin(mask.index), index=df.index) - elif index_col in mask.columns: - result = pd.Series(df.index.isin(mask[index_col]), index=df.index) - else: - raise ValueError( - f"A DataFrame mask must have a column/index with name {index_col}" - ) - elif mask is None or mask == "all": + elif mask is None: result = pd.Series(True, index=df.index) - elif isinstance(mask, Sequence): - result = pd.Series(df.index.isin(mask), index=df.index) else: - result = pd.Series(df.index.isin([mask]), index=df.index) + if isinstance(mask, pd.DataFrame): + if (isinstance(index_cols, str) and mask.index.name == index_cols) or ( + isinstance(index_cols, list) and mask.index.names == index_cols + ): + mask = mask.index + else: + mask = mask.set_index(index_cols).index + elif isinstance(mask, Collection): + pass + else: # single value + mask = [mask] + result = pd.Series(srs.isin(mask), index=df.index) if negate: result = ~result - return result def _df_get_masked_df( self, df: pd.DataFrame, - index_col: str, + index_cols: str, mask: PandasMask | None = None, - columns: list[str] | None = None, + columns: str | list[str] | None = None, negate: bool = False, ) -> pd.DataFrame: - b_mask = self._df_get_bool_mask(df, index_col, mask, negate) + b_mask = self._df_get_bool_mask(df, index_cols, mask, negate) if columns: return df.loc[b_mask, columns] return df.loc[b_mask] + def _df_groupby_cumcount(self, df: pd.DataFrame, by: str | list[str]) -> pd.Series: + return df.groupby(by).cumcount() + def _df_iterator(self, df: pd.DataFrame) -> Iterator[dict[str, Any]]: for index, row in df.iterrows(): row_dict = row.to_dict() - row_dict["unique_id"] = index + if df.index.name: + row_dict[df.index.name] = index + else: + row_dict["index"] = index yield row_dict - def _df_norm(self, df: pd.DataFrame) -> pd.DataFrame: - return self._df_constructor( - data=[np.linalg.norm(df, axis=1), df.index], - columns=[df.columns, df.index.name], - index_col=df.index.name, + def _df_join( + self, + left: pd.DataFrame, + right: pd.DataFrame, + index_cols: str | list[str] | None = None, + on: str | list[str] | None = None, + left_on: str | list[str] | None = None, + right_on: str | list[str] | None = None, + how: Literal["left"] + | Literal["right"] + | Literal["inner"] + | Literal["outer"] + | Literal["cross"] = "left", + suffix="_right", + ) -> pd.DataFrame: + left_index = False + right_index = False + if on: + left_on = on + right_on = on + if left.index.name and left.index.name == left_on: + left_index = True + left_on = None + if right.index.name and right.index.name == right_on: + right_index = True + right_on = None + # Reset index if it is not used as a key to keep it in the DataFrame + if not left_index and left.index.name: + left = left.reset_index() + if not right_index and right.index.name: + right = right.reset_index() + df = left.merge( + right, + how=how, + left_on=left_on, + right_on=right_on, + left_index=left_index, + right_index=right_index, + suffixes=("", suffix), + ) + if index_cols: + return df.set_index(index_cols) + else: + return df + + def _df_mul( + self, + df: pd.DataFrame, + other: pd.DataFrame | Sequence[float | int], + axis: Literal["index", "columns"] = "index", + index_cols: str | list[str] | None = None, + ) -> pd.DataFrame: + return df.mul(other=other, axis=axis) + + @overload + def _df_norm( + self, + df: pd.DataFrame, + srs_name: str = "norm", + include_cols: Literal[False] = False, + ) -> pd.Series: ... + + @overload + def _df_norm( + self, + df: pd.DataFrame, + srs_name: str = "norm", + include_cols: Literal[True] = True, + ) -> pd.DataFrame: ... + + def _df_norm( + self, + df: pd.DataFrame, + srs_name: str = "norm", + include_cols: bool = False, + ) -> pd.Series | pd.DataFrame: + srs = self._srs_constructor( + np.linalg.norm(df, axis=1), name=srs_name, index=df.index ) + if include_cols: + return self._df_with_columns(df, srs, srs_name) + else: + return srs + + def _df_rename_columns( + self, + df: pd.DataFrame, + old_columns: list[str], + new_columns: list[str], + ) -> pd.DataFrame: + return df.rename(columns=dict(zip(old_columns, new_columns))) - def _df_remove( + def _df_reset_index( self, df: pd.DataFrame, - ids: Sequence[Any], - index_col: str | None = None, + index_cols: str | list[str] | None = None, + drop: bool = False, ) -> pd.DataFrame: - return df[~df.index.isin(ids)] + return df.reset_index(level=index_cols, drop=drop) def _df_sample( self, @@ -116,9 +312,42 @@ def _df_sample( shuffle: bool = False, seed: int | None = None, ) -> pd.DataFrame: - return df.sample( - n=n, frac=frac, replace=with_replacement, shuffle=shuffle, random_state=seed - ) + return df.sample(n=n, frac=frac, replace=with_replacement, random_state=seed) + + def _df_set_index( + self, + df: pd.DataFrame, + index_name: str, + new_index: Sequence[Hashable] | None = None, + ) -> pd.DataFrame: + if new_index is None: + df = df.set_index(index_name) + else: + df = df.set_index(new_index) + df.index.name = index_name + return df + + def _df_with_columns( + self, + original_df: pd.DataFrame, + data: pd.DataFrame + | pd.Series + | Sequence[Sequence] + | dict[str | Any] + | Sequence[Any] + | Any, + new_columns: str | list[str] | None = None, + ) -> pd.DataFrame: + df = original_df.copy() + if isinstance(data, dict): + return df.assign(**data) + elif isinstance(data, pd.DataFrame): + data = data.set_index(df.index) + new_columns = data.columns + elif isinstance(data, pd.Series): + data.index = df.index + df.loc[:, new_columns] = data + return df def _srs_constructor( self, @@ -128,3 +357,26 @@ def _srs_constructor( index: Sequence[Any] | None = None, ) -> pd.Series: return pd.Series(data, name=name, dtype=dtype, index=index) + + def _srs_contains( + self, srs: Sequence[Any], values: Any | Sequence[Any] + ) -> pd.Series: + if isinstance(values, Sequence): + return pd.Series(values, index=values).isin(srs) + else: + return pd.Series(values, index=[values]).isin(srs) + + def _srs_range( + self, + name: str, + start: int, + end: int, + step: int = 1, + ) -> pd.Series: + return pd.Series(np.arange(start, end, step), name=name) + + def _srs_to_df(self, srs: pd.Series, index: pd.Index | None = None) -> pd.DataFrame: + df = srs.to_frame() + if index: + return df.set_index(index) + return df diff --git a/mesa_frames/concrete/polars/mixin.py b/mesa_frames/concrete/polars/mixin.py index c3854ee..bae9b53 100644 --- a/mesa_frames/concrete/polars/mixin.py +++ b/mesa_frames/concrete/polars/mixin.py @@ -2,7 +2,9 @@ from typing import Literal import polars as pl -from typing_extensions import Any +from typing_extensions import Any, overload + +from collections.abc import Hashable from mesa_frames.abstract.mixin import DataFrameMixin from mesa_frames.types_ import PolarsMask @@ -12,15 +14,79 @@ class PolarsMixin(DataFrameMixin): # TODO: complete with other dtypes _dtypes_mapping: dict[str, Any] = {"int64": pl.Int64, "bool": pl.Boolean} - def _df_add_columns( + def _df_add( + self, + df: pl.DataFrame, + other: pl.DataFrame | Sequence[float | int], + axis: Literal["index"] | Literal["columns"] = "index", + index_cols: str | list[str] | None = None, + ) -> pl.DataFrame: + if isinstance(other, pl.DataFrame): + if axis == "index": + if index_cols is None: + raise ValueError( + "index_cols must be specified when axis is 'index'" + ) + return ( + df.join(other.select(pl.all().suffix("_add")), on=index_cols) + .with_columns( + [ + (pl.col(col) + pl.col(f"{col}_add")).alias(col) + for col in df.columns + if col not in index_cols + ] + ) + .select(df.columns) + ) + else: + return df.select( + [ + (pl.col(col) + pl.col(other.columns[i])).alias(col) + for i, col in enumerate(df.columns) + ] + ) + elif isinstance(other, Sequence): + if axis == "index": + other_series = pl.Series("addend", other) + return df.with_columns( + [(pl.col(col) + other_series).alias(col) for col in df.columns] + ) + else: + return df.with_columns( + [ + (pl.col(col) + other[i]).alias(col) + for i, col in enumerate(df.columns) + ] + ) + else: + raise ValueError("other must be a DataFrame or a Sequence") + + def _df_all( + self, + df: pl.DataFrame, + name: str, + axis: str = "columns", + index_cols: str | None = None, + ) -> pl.DataFrame: + if axis == "index": + return df.group_by(index_cols).agg(pl.all().all().alias(index_cols)) + return df.select(pl.all().all()) + + def _df_with_columns( self, original_df: pl.DataFrame, new_columns: list[str], data: Any ) -> pl.DataFrame: return original_df.with_columns( **{col: value for col, value in zip(new_columns, data)} ) + def _df_column_names(self, df: pl.DataFrame) -> list[str]: + return df.columns + def _df_combine_first( - self, original_df: pl.DataFrame, new_df: pl.DataFrame, index_cols: list[str] + self, + original_df: pl.DataFrame, + new_df: pl.DataFrame, + index_cols: str | list[str], ) -> pl.DataFrame: new_df = original_df.join(new_df, on=index_cols, how="full", suffix="_right") # Find columns with the _right suffix and update the corresponding original columns @@ -41,30 +107,172 @@ def _df_combine_first( ) return new_df + @overload def _df_concat( self, - dfs: Collection[pl.DataFrame], + objs: Collection[pl.DataFrame], how: Literal["horizontal"] | Literal["vertical"] = "vertical", ignore_index: bool = False, - ) -> pl.DataFrame: + index_cols: str | None = None, + ) -> pl.DataFrame: ... + + @overload + def _df_concat( + self, + objs: Collection[pl.Series], + how: Literal["horizontal"] | Literal["vertical"] = "vertical", + ignore_index: bool = False, + index_cols: str | None = None, + ) -> pl.Series: ... + + def _df_concat( + self, + objs: Collection[pl.DataFrame] | Collection[pl.Series], + how: Literal["horizontal"] | Literal["vertical"] = "vertical", + ignore_index: bool = False, + index_cols: str | None = None, + ) -> pl.Series | pl.DataFrame: return pl.concat( - dfs, how="vertical_relaxed" if how == "vertical" else "horizontal_relaxed" + objs, how="vertical_relaxed" if how == "vertical" else "horizontal_relaxed" ) def _df_constructor( self, data: Sequence[Sequence] | dict[str | Any] | None = None, columns: list[str] | None = None, - index_col: str | list[str] | None = None, + index: Sequence[Hashable] | None = None, + index_cols: str | list[str] | None = None, dtypes: dict[str, str] | None = None, ) -> pl.DataFrame: dtypes = {k: self._dtypes_mapping.get(v, v) for k, v in dtypes.items()} return pl.DataFrame(data=data, schema=dtypes if dtypes else columns) + def _df_contains( + self, + df: pl.DataFrame, + column: str, + values: Sequence[Any], + ) -> pl.Series: + return pl.Series(values, index=values).is_in(df[column]) + + def _df_div( + self, + df: pl.DataFrame, + other: pl.DataFrame | pl.Series | Sequence[float | int], + axis: Literal["index"] | Literal["columns"] = "index", + index_cols: str | list[str] | None = None, + ) -> pl.DataFrame: + if isinstance(other, pl.DataFrame): + if axis == "index": + if index_cols is None: + raise ValueError( + "index_cols must be specified when axis is 'index'" + ) + return ( + df.join(other.select(pl.all().suffix("_div")), on=index_cols) + .with_columns( + [ + (pl.col(col) / pl.col(f"{col}_div")).alias(col) + for col in df.columns + if col not in index_cols + ] + ) + .select(df.columns) + ) + else: # axis == "columns" + return df.select( + [ + (pl.col(col) / pl.col(other.columns[i])).alias(col) + for i, col in enumerate(df.columns) + ] + ) + elif isinstance(other, pl.Series): + if axis == "index": + return df.with_columns( + [ + (pl.col(col) / other).alias(col) + for col in df.columns + if col != other.name + ] + ) + else: # axis == "columns" + return df.with_columns( + [ + (pl.col(col) / other[i]).alias(col) + for i, col in enumerate(df.columns) + ] + ) + elif isinstance(other, Sequence): + if axis == "index": + other_series = pl.Series("divisor", other) + return df.with_columns( + [(pl.col(col) / other_series).alias(col) for col in df.columns] + ) + else: # axis == "columns" + return df.with_columns( + [ + (pl.col(col) / other[i]).alias(col) + for i, col in enumerate(df.columns) + ] + ) + else: + raise ValueError("other must be a DataFrame, Series, or Sequence") + + def _df_drop_columns( + self, + df: pl.DataFrame, + columns: str | list[str], + ) -> pl.DataFrame: + return df.drop(columns) + + def _df_drop_duplicates( + self, + df: pl.DataFrame, + subset: str | list[str] | None = None, + keep: Literal["first", "last", False] = "first", + ) -> pl.DataFrame: + # If subset is None, use all columns + if subset is None: + subset = df.columns + # If subset is a string, convert it to a list + elif isinstance(subset, str): + subset = [subset] + + # Determine the sort order based on 'keep' + if keep == "first": + sort_expr = [pl.col(col).rank("dense", reverse=True) for col in subset] + elif keep == "last": + sort_expr = [pl.col(col).rank("dense") for col in subset] + elif keep is False: + # If keep is False, we don't need to sort, just group and filter + return df.group_by(subset).agg(pl.all().first()).sort(subset) + else: + raise ValueError("'keep' must be either 'first', 'last', or False") + + # Add a rank column, sort by it, and keep only the first row of each group + return ( + df.with_columns(pl.struct(sort_expr).alias("__rank")) + .sort("__rank") + .group_by(subset) + .agg(pl.all().first()) + .sort(subset) + .drop("__rank") + ) + + def _df_filter( + self, + df: pl.DataFrame, + condition: pl.Series, + all: bool = True, + ) -> pl.DataFrame: + if all: + return df.filter(pl.all(condition)) + return df.filter(condition) + def _df_get_bool_mask( self, df: pl.DataFrame, - index_col: str, + index_cols: str | list[str], mask: PolarsMask = None, negate: bool = False, ) -> pl.Series | pl.Expr: @@ -75,20 +283,20 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series: and len(mask) == len(df) ): return mask - return df[index_col].is_in(mask) + return df[index_cols].is_in(mask) if isinstance(mask, pl.Expr): result = mask elif isinstance(mask, pl.Series): result = bool_mask_from_series(mask) elif isinstance(mask, pl.DataFrame): - if index_col in mask.columns: - result = bool_mask_from_series(mask[index_col]) + if index_cols in mask.columns: + result = bool_mask_from_series(mask[index_cols]) elif len(mask.columns) == 1 and mask.dtypes[0] == pl.Boolean: result = bool_mask_from_series(mask[mask.columns[0]]) else: raise KeyError( - f"DataFrame must have an {index_col} column or a single boolean column." + f"DataFrame must have an {index_cols} column or a single boolean column." ) elif mask is None or mask == "all": result = pl.Series([True] * len(df)) @@ -105,26 +313,140 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series: def _df_get_masked_df( self, df: pl.DataFrame, - index_col: str, + index_cols: str, mask: PolarsMask | None = None, columns: list[str] | None = None, negate: bool = False, ) -> pl.DataFrame: - b_mask = self._df_get_bool_mask(df, index_col, mask, negate=negate) + b_mask = self._df_get_bool_mask(df, index_cols, mask, negate=negate) if columns: return df.filter(b_mask)[columns] return df.filter(b_mask) + def _df_groupby_cumcount(self, df: pl.DataFrame, by: str | list[str]) -> pl.Series: + return df.with_columns(pl.col(by).cum_count().alias("cumcount")) + def _df_iterator(self, df: pl.DataFrame) -> Iterator[dict[str, Any]]: return iter(df.iter_rows(named=True)) - def _df_norm(self, df: pl.DataFrame) -> pl.DataFrame: - return df.with_columns(pl.col("*").pow(2).alias("*")).sum_horizontal().sqrt() + def _df_join( + self, + left: pl.DataFrame, + right: pl.DataFrame, + on: str | list[str] | None = None, + left_on: str | list[str] | None = None, + right_on: str | list[str] | None = None, + how: Literal["left"] + | Literal["right"] + | Literal["inner"] + | Literal["outer"] + | Literal["cross"] = "left", + suffix="_right", + ) -> pl.DataFrame: + return left.join( + right, + on=on, + left_on=left_on, + right_on=right_on, + how=how, + lsuffix="", + rsuffix=suffix, + ) + + def _df_mul( + self, + df: pl.DataFrame, + other: pl.DataFrame | Sequence[float | int], + axis: Literal["index", "columns"] = "index", + index_cols: str | list[str] | None = None, + ) -> pl.DataFrame: + if isinstance(other, pl.DataFrame): + if axis == "index": + if index_cols is None: + raise ValueError( + "index_cols must be specified when axis is 'index'" + ) + return ( + df.join(other.select(pl.all().suffix("_mul")), on=index_cols) + .with_columns( + [ + (pl.col(col) * pl.col(f"{col}_mul")).alias(col) + for col in df.columns + if col not in index_cols + ] + ) + .select(df.columns) + ) + else: # axis == "columns" + return df.select( + [ + (pl.col(col) * pl.col(other.columns[i])).alias(col) + for i, col in enumerate(df.columns) + ] + ) + elif isinstance(other, Sequence): + if axis == "index": + other_series = pl.Series("multiplier", other) + return df.with_columns( + [(pl.col(col) * other_series).alias(col) for col in df.columns] + ) + else: + return df.with_columns( + [ + (pl.col(col) * other[i]).alias(col) + for i, col in enumerate(df.columns) + ] + ) + else: + raise ValueError("other must be a DataFrame or a Sequence") + + @overload + def _df_norm( + self, + df: pl.DataFrame, + srs_name: str = "norm", + include_cols: Literal[False] = False, + ) -> pl.Series: ... + + @overload + def _df_norm( + self, + df: pl.Series, + srs_name: str = "norm", + include_cols: Literal[True] = True, + ) -> pl.DataFrame: ... - def _df_remove( - self, df: pl.DataFrame, ids: Sequence[Any], index_col: str | None = None + def _df_norm( + self, + df: pl.DataFrame, + srs_name: str = "norm", + include_cols: bool = False, + ) -> pl.Series | pl.DataFrame: + srs = ( + df.with_columns(pl.col("*").pow(2).alias("*")) + .sum_horizontal() + .sqrt() + .rename(srs_name) + ) + if include_cols: + return df.with_columns(srs_name=srs) + return srs + + def _df_rename_columns( + self, df: pl.DataFrame, old_columns: list[str], new_columns: list[str] + ) -> pl.DataFrame: + return df.rename(dict(zip(old_columns, new_columns))) + + def _df_reset_index( + self, + df: pl.DataFrame, + index_cols: str | list[str] | None = None, + drop: bool = False, ) -> pl.DataFrame: - return df.filter(pl.col(index_col).is_in(ids).not_()) + if drop: + return df.drop(index_cols) + else: + return df def _df_sample( self, @@ -139,6 +461,16 @@ def _df_sample( n=n, frac=frac, replace=with_replacement, shuffle=shuffle, seed=seed ) + def _df_set_index( + self, + df: pl.DataFrame, + index_name: str, + new_index: Sequence[Hashable] | None = None, + ) -> pl.DataFrame: + if new_index is None: + return df + return df.with_columns(index_name=new_index) + def _srs_constructor( self, data: Sequence[Any] | None = None, @@ -147,3 +479,27 @@ def _srs_constructor( index: Sequence[Any] | None = None, ) -> pl.Series: return pl.Series(name=name, values=data, dtype=self._dtypes_mapping[dtype]) + + def _srs_contains( + self, + srs: Sequence[Any], + values: Any | Sequence[Any], + ) -> pl.Series: + return pl.Series(values, index=values).is_in(srs) + + def _srs_range( + self, + name: str, + start: int, + end: int, + step: int = 1, + ) -> pl.Series: + return pl.arange(start=start, end=end, step=step, eager=True).rename(name) + + def _srs_to_df( + self, srs: pl.Series, index: pl.Series | None = None + ) -> pl.DataFrame: + df = srs.to_frame() + if index: + return df.with_columns({index.name: index}) + return df