From d5201f8afde4f43996e5668a440e0b78393e89cf Mon Sep 17 00:00:00 2001 From: Adam Amer <136176500+adamamer20@users.noreply.github.com> Date: Thu, 4 Jul 2024 20:22:08 +0200 Subject: [PATCH] bug fixes, added tests for AgentsDF --- mesa_frames/abstract/agents.py | 22 +- mesa_frames/abstract/mixin.py | 12 +- mesa_frames/concrete/agents.py | 259 +++++-- mesa_frames/concrete/agentset_pandas.py | 42 +- mesa_frames/concrete/agentset_polars.py | 26 +- mesa_frames/concrete/model.py | 5 +- mesa_frames/types.py | 6 +- tests/test_agents.py | 978 ++++++++++++++++++++++++ tests/test_agentset_pandas.py | 92 ++- tests/test_agentset_polars.py | 86 ++- 10 files changed, 1319 insertions(+), 209 deletions(-) create mode 100644 tests/test_agents.py diff --git a/mesa_frames/abstract/agents.py b/mesa_frames/abstract/agents.py index 4a90d3b..4217648 100644 --- a/mesa_frames/abstract/agents.py +++ b/mesa_frames/abstract/agents.py @@ -3,12 +3,11 @@ from abc import abstractmethod from collections.abc import Callable, Collection, Iterable, Iterator, Sequence from contextlib import suppress +from typing import TYPE_CHECKING, Literal from numpy.random import Generator from typing_extensions import Any, Self, overload -from typing import TYPE_CHECKING, Literal - from mesa_frames.abstract.mixin import CopyMixin from mesa_frames.types import BoolSeries, DataFrame, IdsLike, Index, MaskLike, Series @@ -86,7 +85,7 @@ def discard(self, agents, inplace: bool = True) -> Self: ---------- Self """ - with suppress(KeyError): + with suppress(KeyError, ValueError): return self.remove(agents, inplace=inplace) return self._get_obj(inplace) @@ -363,22 +362,20 @@ def sort( def __add__(self, other) -> Self: return self.add(agents=other, inplace=False) - def __contains__(self, id: int) -> bool: + def __contains__(self, agents: int | IdsLike | AgentSetDF) -> bool: """Check if an agent is in the AgentContainer. Parameters ---------- - id : Hashable - The ID(s) to check for. + id : int | IdsLike | AgentSetDF + The ID(s) or AgentSetDF to check for. Returns ------- bool True if the agent is in the AgentContainer, False otherwise. """ - if not isinstance(id, int): - raise TypeError("id must be an integer") - return self.contains(ids=id) + return self.contains(agents=agents) def __getitem__( self, @@ -511,7 +508,7 @@ def __getattr__(self, name: str) -> Any | dict[str, Any]: """ @abstractmethod - def __iter__(self) -> Iterator: + def __iter__(self) -> Iterator[dict[str, Any]]: """Iterate over the agents in the AgentContainer. Returns @@ -973,15 +970,12 @@ def __getitem__( ), ) -> Series | DataFrame: attr = super().__getitem__(key) - assert isinstance(attr, (Series, DataFrame)) + assert isinstance(attr, (Series, DataFrame, Index)) return attr def __len__(self) -> int: return len(self._agents) - def __iter__(self) -> Iterator: - return iter(self._agents) - def __repr__(self) -> str: return f"{self.__class__.__name__}\n {str(self._agents)}" diff --git a/mesa_frames/abstract/mixin.py b/mesa_frames/abstract/mixin.py index 8d73b84..a62752d 100644 --- a/mesa_frames/abstract/mixin.py +++ b/mesa_frames/abstract/mixin.py @@ -26,7 +26,7 @@ class CopyMixin(ABC): _description_ """ - _copy_with_method: dict[str, tuple[str, list[str]]] + _copy_with_method: dict[str, tuple[str, list[str]]] = {} _copy_only_reference: list[str] = [ "_model", ] @@ -38,6 +38,7 @@ def copy( self, deep: bool = False, memo: dict | None = None, + skip: list[str] | None = None, ) -> Self: """Create a copy of the Class. @@ -48,10 +49,12 @@ def copy( If True, all attributes of the AgentContainer will be recursively copied (except attributes in self._copy_reference_only). If False, only the top-level attributes will be copied. Defaults to False. - memo : dict | None, optional A dictionary used to track already copied objects during deep copy. Defaults to None. + skip : list[str] | None, optional + A list of attribute names to skip during the copy process. + Defaults to None. Returns ------- @@ -61,6 +64,9 @@ def copy( cls = self.__class__ obj = cls.__new__(cls) + if skip is None: + skip = [] + if deep: if not memo: memo = {} @@ -71,6 +77,7 @@ def copy( for k, v in attributes.items() if k not in self._copy_with_method and k not in self._copy_only_reference + and k not in skip ] else: [ @@ -78,6 +85,7 @@ def copy( for k, v in self.__dict__.items() if k not in self._copy_with_method and k not in self._copy_only_reference + and k not in skip ] # Copy attributes with a reference only diff --git a/mesa_frames/concrete/agents.py b/mesa_frames/concrete/agents.py index d103444..14891b8 100644 --- a/mesa_frames/concrete/agents.py +++ b/mesa_frames/concrete/agents.py @@ -5,16 +5,24 @@ import polars as pl from typing_extensions import Any, Self, overload +from typing import TYPE_CHECKING + from mesa_frames.abstract.agents import AgentContainer, AgentSetDF -from mesa_frames.types import DataFrame, IdsLike, MaskLike, Series +from mesa_frames.types import ( + AgnosticMask, + BoolSeries, + DataFrame, + IdsLike, + MaskLike, + Series, +) + +if TYPE_CHECKING: + from mesa_frames.concrete.model import ModelDF class AgentsDF(AgentContainer): _agentsets: list[AgentSetDF] - _copy_with_method: dict[str, tuple[str, list[str]]] = { - "_agentsets": ("copy", []), - } - _backend: str _ids: pl.Series """A collection of AgentSetDFs. All agents of the model are stored here. @@ -24,8 +32,6 @@ class AgentsDF(AgentContainer): The agent sets contained in this collection. _copy_with_method : dict[AgentSetDF, tuple[str, list[str]]] A dictionary of attributes to copy with a specified method and arguments. - _backend : str - The backend used for data operations. Properties ---------- @@ -86,12 +92,13 @@ class AgentsDF(AgentContainer): Get the string representation of the AgentsDF. """ - def __init__(self) -> None: + def __init__(self, model: "ModelDF") -> None: + self._model = model self._agentsets = [] self._ids = pl.Series(name="unique_id", dtype=pl.Int64) def add( - self, agentsets: AgentSetDF | Iterable[AgentSetDF], inplace: bool = True + self, agents: AgentSetDF | Iterable[AgentSetDF], inplace: bool = True ) -> Self: """Add an AgentSetDF to the AgentsDF. @@ -113,7 +120,7 @@ def add( If some agentsets are already present in the AgentsDF or if the IDs are not unique. """ obj = self._get_obj(inplace) - other_list = obj._return_agentsets_list(agentsets) + other_list = obj._return_agentsets_list(agents) if obj._check_agentsets_presence(other_list).any(): raise ValueError("Some agentsets are already present in the AgentsDF.") new_ids = pl.concat( @@ -132,7 +139,7 @@ def contains(self, agents: int | AgentSetDF) -> bool: ... def contains(self, agents: IdsLike | Iterable[AgentSetDF]) -> pl.Series: ... def contains( - self, agents: AgentSetDF | IdsLike | Iterable[AgentSetDF] + self, agents: IdsLike | AgentSetDF | Iterable[AgentSetDF] ) -> bool | pl.Series: if isinstance(agents, AgentSetDF): return self._check_agentsets_presence([agents]).any() @@ -143,6 +150,8 @@ def contains( return self._check_agentsets_presence(list(agents)) else: agents = cast(IdsLike, agents) + if isinstance(agents, int): + return agents in self._ids return pl.Series(agents).is_in(self._ids) @overload @@ -150,7 +159,7 @@ def do( self, method_name: str, *args, - mask: MaskLike | None = None, + mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None, return_results: Literal[False] = False, inplace: bool = True, **kwargs, @@ -161,7 +170,7 @@ def do( self, method_name: str, *args, - mask: MaskLike | None = None, + mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None, return_results: Literal[True], inplace: bool = True, **kwargs, @@ -171,12 +180,13 @@ def do( self, method_name: str, *args, - mask: MaskLike | None = None, + mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None, return_results: bool = False, inplace: bool = True, **kwargs, ) -> Self | Any: obj = self._get_obj(inplace) + agentsets_masks = obj._get_bool_masks(mask) if return_results: return { agentset: agentset.do( @@ -187,7 +197,7 @@ def do( **kwargs, inplace=inplace, ) - for agentset in obj._agentsets + for agentset, mask in agentsets_masks.items() } else: obj._agentsets = [ @@ -199,32 +209,40 @@ def do( **kwargs, inplace=inplace, ) - for agentset in obj._agentsets + for agentset, mask in agentsets_masks.items() ] return obj def get( self, - attr_names: str | list[str] | None = None, - mask: MaskLike | None = None, + attr_names: str | Collection[str] | None = None, + mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None, ) -> dict[AgentSetDF, Series] | dict[AgentSetDF, DataFrame]: + agentsets_masks = self._get_bool_masks(mask) return { - agentset: agentset.get(attr_names, mask) for agentset in self._agentsets + agentset: agentset.get(attr_names, mask) + for agentset, mask in agentsets_masks.items() } def remove( self, agents: AgentSetDF | Iterable[AgentSetDF] | IdsLike, inplace: bool = True ) -> Self: obj = self._get_obj(inplace) - deleted = 0 if isinstance(agents, AgentSetDF): - self._agentsets.remove(agents) + # We have to get the index of the original AgentSetDF because the copy made AgentSetDFs with different hash + id = self._agentsets.index(agents) + obj._agentsets.pop(id) elif isinstance(agents, Iterable) and isinstance( next(iter(agents)), AgentSetDF - ): # Faster than controlling every AgentSetDF - for agentset in iter(agents): - self._agentsets.remove(agentset) # type: ignore (Pylance can't recognize agents as Iterable[AgentSetDF]) + ): + ids = [self._agentsets.index(agentset) for agentset in iter(agents)] + ids.sort(reverse=True) + for id in ids: + obj._agentsets.pop(id) else: # IDsLike + deleted = 0 + if isinstance(agents, int): + agents = [agents] for agentset in obj._agentsets: initial_len = len(agentset) agentset.discard(agents, inplace=True) @@ -235,42 +253,56 @@ def remove( ) return obj - def set( - self, - attr_names: str | dict[AgentSetDF, Any] | Collection[str], - values: Any | None = None, - mask: MaskLike | None = None, - inplace: bool = True, - ) -> Self: - obj = self._get_obj(inplace) - obj._agentsets = [ - agentset.set( - attr_names=attr_names, values=values, mask=mask, inplace=inplace - ) - for agentset in obj._agentsets - ] - return obj - def select( self, - mask: MaskLike | None = None, + mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None, filter_func: Callable[[AgentSetDF], MaskLike] | None = None, n: int | None = None, inplace: bool = True, negate: bool = False, ) -> Self: obj = self._get_obj(inplace) + agentsets_masks = obj._get_bool_masks(mask) + if n is not None: + n = n // len(agentsets_masks) obj._agentsets = [ agentset.select( mask=mask, filter_func=filter_func, n=n, negate=negate, inplace=inplace ) - for agentset in obj._agentsets + for agentset, mask in agentsets_masks.items() ] return obj + def set( + self, + attr_names: str | dict[AgentSetDF, Any] | Collection[str], + values: Any | None = None, + mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None, + inplace: bool = True, + ) -> Self: + obj = self._get_obj(inplace) + agentsets_masks = obj._get_bool_masks(mask) + if isinstance(attr_names, dict): + for agentset, values in attr_names.items(): + if not inplace: + # We have to get the index of the original AgentSetDF because the copy made AgentSetDFs with different hash + id = self._agentsets.index(agentset) + agentset = obj._agentsets[id] + agentset.set( + attr_names=values, mask=agentsets_masks[agentset], inplace=True + ) + else: + obj._agentsets = [ + agentset.set( + attr_names=attr_names, values=values, mask=mask, inplace=True + ) + for agentset, mask in agentsets_masks.items() + ] + return obj + def shuffle(self, inplace: bool = True) -> Self: obj = self._get_obj(inplace) - obj._agentsets = [agentset.shuffle(inplace) for agentset in obj._agentsets] + obj._agentsets = [agentset.shuffle(inplace=True) for agentset in obj._agentsets] return obj def sort( @@ -295,27 +327,29 @@ def _check_ids_presence(self, other: list[AgentSetDF]) -> pl.DataFrame: other : list[AgentSetDF] The AgentSetDFs to check. - Raises - ------ - ValueError - If the agent set contains IDs already present in agents. + Returns + ------- + pl.DataFrame + A DataFrame with the unique IDs and a boolean column indicating if they are present. """ presence_df = pl.DataFrame( - data={"unique_id": self._ids}, + data={"unique_id": self._ids, "present": True}, schema={"unique_id": pl.Int64, "present": pl.Boolean}, ) for agentset in other: - new_ids = pl.Series(agentset["unique_id"]) + new_ids = pl.Series(agentset.index) presence_df = pl.concat( [ presence_df, ( new_ids.is_in(presence_df["unique_id"]) - .to_frame() - .with_columns("unique_id", new_ids) + .to_frame("present") + .with_columns(unique_id=new_ids) + .select(["unique_id", "present"]) ), ] ) + presence_df = presence_df.slice(self._ids.len()) return presence_df def _check_agentsets_presence(self, other: list[AgentSetDF]) -> pl.Series: @@ -336,6 +370,17 @@ def _check_agentsets_presence(self, other: list[AgentSetDF]) -> pl.Series: [agentset in other_set for agentset in self._agentsets], dtype=pl.Boolean ) + def _get_bool_masks( + self, + mask: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] = None, + ) -> dict[AgentSetDF, BoolSeries]: + return_dictionary = {} + if not isinstance(mask, dict): + mask = {agentset: mask for agentset in self._agentsets} + for agentset, mask in mask.items(): + return_dictionary[agentset] = agentset._get_bool_mask(mask) + return return_dictionary + def _return_agentsets_list( self, agentsets: AgentSetDF | Iterable[AgentSetDF] ) -> list[AgentSetDF]: @@ -367,8 +412,39 @@ def __add__(self, other: AgentSetDF | Iterable[AgentSetDF]) -> Self: return super().__add__(other) def __getattr__(self, name: str) -> dict[AgentSetDF, Any]: + if name.startswith("_"): # Avoids infinite recursion of private attributes + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) return {agentset: getattr(agentset, name) for agentset in self._agentsets} + @overload + def __getitem__( + self, key: str | tuple[dict[AgentSetDF, MaskLike], str] + ) -> dict[str, Series]: ... + + @overload + def __getitem__( + self, + key: Collection[str] + | AgnosticMask + | IdsLike + | tuple[dict[AgentSetDF, MaskLike], Collection[str]], + ) -> dict[str, DataFrame]: ... + + def __getitem__( + self, + key: ( + str + | Collection[str] + | AgnosticMask + | IdsLike + | tuple[dict[AgentSetDF, MaskLike], str] + | tuple[dict[AgentSetDF, MaskLike], Collection[str]] + ), + ) -> dict[str, Series] | dict[str, DataFrame]: + return super().__getitem__(key) + def __iadd__(self, other: AgentSetDF | Iterable[AgentSetDF]) -> Self: """Add AgentSetDFs to the AgentsDF through the += operator. @@ -384,17 +460,30 @@ def __iadd__(self, other: AgentSetDF | Iterable[AgentSetDF]) -> Self: """ return super().__iadd__(other) - def __iter__(self) -> Iterator: - return ( - agent for agentset in self._agentsets for agent in iter(agentset._backend) - ) + def __iter__(self) -> Iterator[dict[str, Any]]: + return (agent for agentset in self._agentsets for agent in iter(agentset)) + + def __isub__(self, agents: AgentSetDF | Iterable[AgentSetDF] | IdsLike) -> Self: + """Remove AgentSetDFs from the AgentsDF through the -= operator. + + Parameters + ---------- + agents : AgentSetDF | Iterable[AgentSetDF] | IdsLike + The AgentSetDFs to remove. + + Returns + ------- + Self + The updated AgentsDF. + """ + return super().__isub__(agents) + + def __len__(self) -> int: + return sum(len(agentset._agents) for agentset in self._agentsets) def __repr__(self) -> str: return "\n".join([repr(agentset) for agentset in self._agentsets]) - def __str__(self) -> str: - return "\n".join([str(agentset) for agentset in self._agentsets]) - def __reversed__(self) -> Iterator: return ( agent @@ -402,8 +491,37 @@ def __reversed__(self) -> Iterator: for agent in reversed(agentset._backend) ) - def __len__(self) -> int: - return sum(len(agentset._agents) for agentset in self._agentsets) + def __setitem__( + self, + key: ( + str + | Collection[str] + | AgnosticMask + | IdsLike + | tuple[dict[AgentSetDF, MaskLike], str] + | tuple[dict[AgentSetDF, MaskLike], Collection[str]] + ), + values: Any, + ) -> None: + super().__setitem__(key, values) + + def __str__(self) -> str: + return "\n".join([str(agentset) for agentset in self._agentsets]) + + def __sub__(self, agents: AgentSetDF | Iterable[AgentSetDF] | IdsLike) -> Self: + """Remove AgentSetDFs from a new AgentsDF through the - operator. + + Parameters + ---------- + other : AgentSetDF | Iterable[AgentSetDF] | IdsLike + The AgentSetDFs to remove. + + Returns + ------- + AgentsDF + A new AgentsDF with the removed AgentSetDFs. + """ + return super().__sub__(agents) @property def agents(self) -> dict[AgentSetDF, DataFrame]: @@ -424,13 +542,26 @@ def agents(self, other: Iterable[AgentSetDF]) -> None: def active_agents(self) -> dict[AgentSetDF, DataFrame]: return {agentset: agentset.active_agents for agentset in self._agentsets} + @active_agents.setter + def active_agents( + self, agents: AgnosticMask | IdsLike | dict[AgentSetDF, MaskLike] + ) -> None: + self.select(agents, inplace=True) + @property - def agentsets_by_type(self) -> dict[type[AgentSetDF], list[AgentSetDF]]: - dictionary = defaultdict(list) + def agentsets_by_type(self) -> dict[type[AgentSetDF], Self]: + def copy_without_agentsets() -> Self: + return self.copy(deep=False, skip=["_agentsets"]) + + dictionary: dict[type[AgentSetDF], Self] = defaultdict(copy_without_agentsets) + for agentset in self._agentsets: - dictionary[agentset.__class__] = dictionary[agentset.__class__] + [agentset] + agents_df = dictionary[agentset.__class__] + agents_df._agentsets = [] + agents_df._agentsets = agents_df._agentsets + [agentset] + dictionary[agentset.__class__] = agents_df return dictionary @property - def inactive_agents(self): + def inactive_agents(self) -> dict[AgentSetDF, DataFrame]: return {agentset: agentset.inactive_agents for agentset in self._agentsets} diff --git a/mesa_frames/concrete/agentset_pandas.py b/mesa_frames/concrete/agentset_pandas.py index ee4542b..8f4ce4d 100644 --- a/mesa_frames/concrete/agentset_pandas.py +++ b/mesa_frames/concrete/agentset_pandas.py @@ -1,11 +1,10 @@ from collections.abc import Callable, Collection, Iterable, Iterator, Sequence +from typing import TYPE_CHECKING import pandas as pd import polars as pl from typing_extensions import Any, Self, overload -from typing import TYPE_CHECKING - from mesa_frames.abstract.agents import AgentSetDF from mesa_frames.concrete.agentset_polars import AgentSetPolars from mesa_frames.types import PandasIdsLike, PandasMaskLike @@ -152,22 +151,22 @@ def add( return obj @overload - def contains(self, ids: int) -> bool: ... + def contains(self, agents: int) -> bool: ... @overload - def contains(self, ids: PandasIdsLike) -> pd.Series: ... + def contains(self, agents: PandasIdsLike) -> pd.Series: ... - def contains(self, ids: PandasIdsLike) -> bool | pd.Series: - if isinstance(ids, pd.Series): - return ids.isin(self._agents.index) - elif isinstance(ids, pd.Index): + def contains(self, agents: PandasIdsLike) -> bool | pd.Series: + if isinstance(agents, pd.Series): + return agents.isin(self._agents.index) + elif isinstance(agents, pd.Index): return pd.Series( - ids.isin(self._agents.index), index=ids, dtype=pd.BooleanDtype() + agents.isin(self._agents.index), index=agents, dtype=pd.BooleanDtype() ) - elif isinstance(ids, Collection): - return pd.Series(list(ids), index=list(ids)).isin(self._agents.index) + elif isinstance(agents, Collection): + return pd.Series(list(agents), index=list(agents)).isin(self._agents.index) else: - return ids in self._agents.index + return agents in self._agents.index def get( self, @@ -339,10 +338,6 @@ def _get_masked_df( if isinstance(mask, pd.Series) and mask.dtype == bool: return self._agents.loc[mask] elif isinstance(mask, pd.DataFrame): - if not mask.index.isin(self._agents.index).all(): - raise KeyError( - "Some 'unique_id' of mask are not present in DataFrame 'unique_id'." - ) if mask.index.name != "unique_id": if "unique_id" in mask.columns: mask.set_index("unique_id", inplace=True, drop=True) @@ -352,10 +347,6 @@ def _get_masked_df( self._agents, on="unique_id", how="left" ) elif isinstance(mask, pd.Series): - if not mask.isin(self._agents.index).all(): - raise KeyError( - "Some 'unique_id' of mask are not present in DataFrame 'unique_id'." - ) mask_df = mask.to_frame("unique_id").set_index("unique_id") return mask_df.join(self._agents, on="unique_id", how="left") elif mask is None or mask == "all": @@ -364,10 +355,6 @@ def _get_masked_df( return self._agents.loc[self._mask] else: mask_series = pd.Series(mask) - if not mask_series.isin(self._agents.index).all(): - raise KeyError( - "Some 'unique_id' of mask are not present in DataFrame 'unique_id'." - ) mask_df = mask_series.to_frame("unique_id").set_index("unique_id") return mask_df.join(self._agents, on="unique_id", how="left") @@ -409,8 +396,11 @@ def __getattr__(self, name: str) -> Any: super().__getattr__(name) return getattr(self._agents, name) - def __iter__(self) -> Iterator: - return iter(self._agents.iterrows()) + def __iter__(self) -> Iterator[dict[str, Any]]: + for index, row in self._agents.iterrows(): + row_dict = row.to_dict() + row_dict["unique_id"] = index + yield row_dict def __len__(self) -> int: return len(self._agents) diff --git a/mesa_frames/concrete/agentset_polars.py b/mesa_frames/concrete/agentset_polars.py index 6646ece..099f72f 100644 --- a/mesa_frames/concrete/agentset_polars.py +++ b/mesa_frames/concrete/agentset_polars.py @@ -1,11 +1,10 @@ from collections.abc import Callable, Collection, Iterable, Iterator, Sequence +from typing import TYPE_CHECKING import polars as pl from polars.type_aliases import IntoExpr from typing_extensions import Any, Self, overload -from typing import TYPE_CHECKING - from mesa_frames.concrete.agents import AgentSetDF from mesa_frames.types import PolarsIdsLike, PolarsMaskLike @@ -169,21 +168,21 @@ def add( return obj @overload - def contains(self, ids: int) -> bool: ... + def contains(self, agents: int) -> bool: ... @overload - def contains(self, ids: PolarsIdsLike) -> pl.Series: ... + def contains(self, agents: PolarsIdsLike) -> pl.Series: ... def contains( self, - ids: PolarsIdsLike, + agents: PolarsIdsLike, ) -> bool | pl.Series: - if isinstance(ids, pl.Series): - return ids.is_in(self._agents["unique_id"]) - elif isinstance(ids, Collection): - return pl.Series(ids).is_in(self._agents["unique_id"]) + if isinstance(agents, pl.Series): + return agents.is_in(self._agents["unique_id"]) + elif isinstance(agents, Collection): + return pl.Series(agents).is_in(self._agents["unique_id"]) else: - return ids in self._agents["unique_id"] + return agents in self._agents["unique_id"] def get( self, @@ -397,11 +396,6 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series: and len(mask) == len(self._agents) ): return mask - else: - if not mask.is_in(self._agents["unique_id"]).all(): - raise KeyError( - "Some 'unique_ids' of mask are not present in DataFrame 'unique_id'." - ) return self._agents["unique_id"].is_in(mask) if isinstance(mask, pl.Expr): @@ -524,7 +518,7 @@ def __getitem__( assert isinstance(attr, (pl.Series, pl.DataFrame)) return attr - def __iter__(self) -> Iterator: + def __iter__(self) -> Iterator[dict[str, Any]]: return iter(self._agents.iter_rows(named=True)) def __len__(self) -> int: diff --git a/mesa_frames/concrete/model.py b/mesa_frames/concrete/model.py index c31d31a..be4ef43 100644 --- a/mesa_frames/concrete/model.py +++ b/mesa_frames/concrete/model.py @@ -1,10 +1,9 @@ from collections.abc import Sequence +from typing import TYPE_CHECKING import numpy as np from typing_extensions import Any -from typing import TYPE_CHECKING - from mesa_frames.concrete.agents import AgentsDF if TYPE_CHECKING: @@ -77,7 +76,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.running = True self.schedule = None self.current_id = 0 - self._agents = AgentsDF() + self._agents = AgentsDF(self) def get_agents_of_type(self, agent_type: type) -> "AgentSetDF": """Retrieve the AgentSetDF of a specified type. diff --git a/mesa_frames/types.py b/mesa_frames/types.py index fcef6d6..232aa0f 100644 --- a/mesa_frames/types.py +++ b/mesa_frames/types.py @@ -1,12 +1,12 @@ -from collections.abc import Collection, Hashable +from collections.abc import Collection +from typing import Literal import pandas as pd import polars as pl from numpy import ndarray -from typing import Literal ####----- Agnostic Types -----#### -AgnosticMask = Literal["all", "active"] | Hashable | None +AgnosticMask = Literal["all", "active"] | None AgnosticIds = int | Collection[int] ###----- Pandas Types -----### diff --git a/tests/test_agents.py b/tests/test_agents.py new file mode 100644 index 0000000..eda4ae2 --- /dev/null +++ b/tests/test_agents.py @@ -0,0 +1,978 @@ +from copy import copy, deepcopy + +import pandas as pd +import polars as pl +import pytest + +from mesa_frames import AgentsDF, ModelDF +from mesa_frames.abstract.agents import AgentSetDF +from mesa_frames.types import MaskLike +from tests.test_agentset_pandas import ( + ExampleAgentSetPandas, + fix1_AgentSetPandas, + fix2_AgentSetPandas, +) +from tests.test_agentset_polars import ( + ExampleAgentSetPolars, + fix2_AgentSetPolars, +) + + +# This serves otherwise ruff complains about the two fixtures not being used +def not_called(): + fix1_AgentSetPandas() + fix2_AgentSetPandas() + fix2_AgentSetPolars() + + +@pytest.fixture +def fix_AgentsDF( + fix1_AgentSetPandas: ExampleAgentSetPandas, + fix2_AgentSetPolars: ExampleAgentSetPolars, +) -> AgentsDF: + model = ModelDF() + agents = AgentsDF(model) + agents.add([fix1_AgentSetPandas, fix2_AgentSetPolars]) + return agents + + +class Test_AgentsDF: + def test___init__(self): + model = ModelDF() + agents = AgentsDF(model) + assert agents.model == model + assert isinstance(agents._agentsets, list) + assert len(agents._agentsets) == 0 + assert isinstance(agents._ids, pl.Series) + assert agents._ids.is_empty() + assert agents._ids.name == "unique_id" + + def test_add( + self, + fix1_AgentSetPandas: ExampleAgentSetPandas, + fix2_AgentSetPolars: ExampleAgentSetPolars, + ): + model = ModelDF() + agents = AgentsDF(model) + agentset_pandas = fix1_AgentSetPandas + agentset_polars = fix2_AgentSetPolars + + # Test with a single AgentSetPandas + result = agents.add(agentset_pandas, inplace=False) + assert result._agentsets[0] is agentset_pandas + assert result._ids.to_list() == agentset_pandas._agents.index.to_list() + + # Test with a single AgentSetPolars + result = agents.add(agentset_polars, inplace=False) + assert result._agentsets[0] is agentset_polars + assert result._ids.to_list() == agentset_polars._agents["unique_id"].to_list() + + # Test with a list of AgentSetDFs + result = agents.add([agentset_pandas, agentset_polars], inplace=True) + assert result._agentsets[0] is agentset_pandas + assert result._agentsets[1] is agentset_polars + assert ( + result._ids.to_list() + == agentset_pandas._agents.index.to_list() + + agentset_polars._agents["unique_id"].to_list() + ) + + # Test if adding the same AgentSetDF raises ValueError + with pytest.raises(ValueError): + agents.add(agentset_pandas, inplace=False) + + def test_contains( + self, fix2_AgentSetPandas: ExampleAgentSetPandas, fix_AgentsDF: AgentsDF + ): + agents = fix_AgentsDF + agentset_pandas = agents._agentsets[0] + + # Test with an AgentSetDF + assert agents.contains(agentset_pandas) + + # Test with an AgentSetDF not present + assert not agents.contains(fix2_AgentSetPandas) + + # Test with an iterable of AgentSetDFs + assert agents.contains([agentset_pandas, fix2_AgentSetPandas]).to_list() == [ + True, + False, + ] + + # Test with single id + assert agents.contains(0) + + # Test with a list of ids + assert agents.contains([0, 10]).to_list() == [True, False] + + def test_copy(self, fix_AgentsDF: AgentsDF): + agents = fix_AgentsDF + agents.test_list = [[1, 2, 3]] + + # Test with deep=False + agents2 = agents.copy(deep=False) + agents2.test_list[0].append(4) + assert agents.test_list[0][-1] == agents2.test_list[0][-1] + assert agents.model == agents2.model + assert agents._agentsets[0] == agents2._agentsets[0] + assert (agents._ids == agents2._ids).all() + + # Test with deep=True + agents2 = fix_AgentsDF.copy(deep=True) + agents2.test_list[0].append(4) + assert agents.test_list[-1] != agents2.test_list[-1] + assert agents.model == agents2.model + assert agents._agentsets[0] != agents2._agentsets[0] + assert (agents._ids == agents2._ids).all() + + def test_discard( + self, fix_AgentsDF: AgentsDF, fix2_AgentSetPandas: ExampleAgentSetPandas + ): + agents = fix_AgentsDF + # Test with a single AgentSetDF + agentset_polars = agents._agentsets[1] + result = agents.discard(agents._agentsets[0], inplace=False) + assert isinstance(result._agentsets[0], ExampleAgentSetPolars) + assert len(result._agentsets) == 1 + + # Test with a list of AgentSetDFs + result = agents.discard(agents._agentsets.copy(), inplace=False) + assert len(result._agentsets) == 0 + + # Test with IDs + ids = [ + agents._agentsets[0]._agents.index[0], + agents._agentsets[1]._agents["unique_id"][0], + ] + agentset_pandas = agents._agentsets[0] + agentset_polars = agents._agentsets[1] + result = agents.discard(ids, inplace=False) + assert result._agentsets[0].index[0] == agentset_pandas._agents.index[1] + assert ( + result._agentsets[1].agents["unique_id"][0] + == agentset_polars._agents["unique_id"][1] + ) + + # Test if removing an AgentSetDF not present raises ValueError + result = agents.discard(fix2_AgentSetPandas, inplace=False) + + # Test if removing an ID not present raises KeyError + assert -100 not in agents._ids + result = agents.discard(-100, inplace=False) + + def test_do(self, fix_AgentsDF: AgentsDF): + agents = fix_AgentsDF + + expected_result_0 = agents._agentsets[0].agents["wealth"] + expected_result_0 += 1 + + expected_result_1 = agents._agentsets[1].agents["wealth"] + expected_result_1 += 1 + + # Test with no return_results, no mask, inplace + agents.do("add_wealth", 1) + assert ( + agents._agentsets[0].agents["wealth"].to_list() + == expected_result_0.to_list() + ) + assert ( + agents._agentsets[1].agents["wealth"].to_list() + == expected_result_1.to_list() + ) + + # Test with return_results=True, no mask, inplace + expected_result_0 = agents._agentsets[0].agents["wealth"] + expected_result_0 += 1 + + expected_result_1 = agents._agentsets[1].agents["wealth"] + expected_result_1 += 1 + assert agents.do("add_wealth", 1, return_results=True) == { + agents._agentsets[0]: None, + agents._agentsets[1]: None, + } + assert ( + agents._agentsets[0].agents["wealth"].to_list() + == expected_result_0.to_list() + ) + assert ( + agents._agentsets[1].agents["wealth"].to_list() + == expected_result_1.to_list() + ) + + # Test with a mask, inplace + mask0 = ( + agents._agentsets[0].agents["wealth"] > 10 + ) # No agent should be selected + mask1 = ( + agents._agentsets[1].agents["wealth"] > 10 + ) # All agents should be selected + mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1} + + expected_result_0 = agents._agentsets[0].agents["wealth"] + expected_result_1 = agents._agentsets[1].agents["wealth"] + expected_result_1 += 1 + + agents.do("add_wealth", 1, mask=mask_dictionary) + assert ( + agents._agentsets[0].agents["wealth"].to_list() + == expected_result_0.to_list() + ) + assert ( + agents._agentsets[1].agents["wealth"].to_list() + == expected_result_1.to_list() + ) + + def test_get( + self, + fix_AgentsDF: AgentsDF, + fix1_AgentSetPandas: ExampleAgentSetPandas, + fix2_AgentSetPolars: ExampleAgentSetPolars, + ): + agents = fix_AgentsDF + + # Test with a single attribute + assert ( + agents.get("wealth")[fix1_AgentSetPandas].to_list() + == fix1_AgentSetPandas._agents["wealth"].to_list() + ) + assert ( + agents.get("wealth")[fix2_AgentSetPolars].to_list() + == fix2_AgentSetPolars._agents["wealth"].to_list() + ) + + # Test with a list of attributes + result = agents.get(["wealth", "age"]) + assert result[fix1_AgentSetPandas].columns.to_list() == ["wealth", "age"] + assert ( + result[fix1_AgentSetPandas]["wealth"].to_list() + == fix1_AgentSetPandas._agents["wealth"].to_list() + ) + assert ( + result[fix1_AgentSetPandas]["age"].to_list() + == fix1_AgentSetPandas._agents["age"].to_list() + ) + assert result[fix2_AgentSetPolars].columns == ["wealth", "age"] + assert ( + result[fix2_AgentSetPolars]["wealth"].to_list() + == fix2_AgentSetPolars._agents["wealth"].to_list() + ) + assert ( + result[fix2_AgentSetPolars]["age"].to_list() + == fix2_AgentSetPolars._agents["age"].to_list() + ) + + # Test with a single attribute and a mask + mask0 = ( + fix1_AgentSetPandas._agents["wealth"] + > fix1_AgentSetPandas._agents["wealth"][0] + ) + mask1 = ( + fix2_AgentSetPolars._agents["wealth"] + > fix2_AgentSetPolars._agents["wealth"][0] + ) + mask_dictionary = {fix1_AgentSetPandas: mask0, fix2_AgentSetPolars: mask1} + result = agents.get("wealth", mask=mask_dictionary) + assert ( + result[fix1_AgentSetPandas].to_list() + == fix1_AgentSetPandas._agents["wealth"].to_list()[1:] + ) + assert ( + result[fix2_AgentSetPolars].to_list() + == fix2_AgentSetPolars._agents["wealth"].to_list()[1:] + ) + + def test_remove( + self, + fix_AgentsDF: AgentsDF, + fix2_AgentSetPandas: ExampleAgentSetPandas, + ): + agents = fix_AgentsDF + + # Test with a single AgentSetDF + agentset_polars = agents._agentsets[1] + result = agents.remove(agents._agentsets[0], inplace=False) + assert isinstance(result._agentsets[0], ExampleAgentSetPolars) + assert len(result._agentsets) == 1 + + # Test with a list of AgentSetDFs + result = agents.remove(agents._agentsets.copy(), inplace=False) + assert len(result._agentsets) == 0 + + # Test with IDs + ids = [ + agents._agentsets[0]._agents.index[0], + agents._agentsets[1]._agents["unique_id"][0], + ] + agentset_pandas = agents._agentsets[0] + agentset_polars = agents._agentsets[1] + result = agents.remove(ids, inplace=False) + assert result._agentsets[0].index[0] == agentset_pandas._agents.index[1] + assert ( + result._agentsets[1].agents["unique_id"][0] + == agentset_polars._agents["unique_id"][1] + ) + + # Test if removing an AgentSetDF not present raises ValueError + with pytest.raises(ValueError): + result = agents.remove(fix2_AgentSetPandas, inplace=False) + + # Test if removing an ID not present raises KeyError + assert -100 not in agents._ids + with pytest.raises(KeyError): + result = agents.remove(-100, inplace=False) + + def test_select(self, fix_AgentsDF: AgentsDF): + agents = fix_AgentsDF + + def compare_dataframes(df1, df2): + if isinstance(df1, pd.DataFrame) and isinstance(df2, pd.DataFrame): + # For pandas DataFrames + return df1.equals(df2) + elif isinstance(df1, pl.DataFrame) and isinstance(df2, pl.DataFrame): + # For polars DataFrames + return df1.frame_equal(df2) + else: + # If the types are not the same, they are not equal + return False + + # Test with default arguments. Should select all agents + selected = agents.select(inplace=False) + active_agents_dict = selected.active_agents + agents_dict = selected.agents + assert active_agents_dict.keys() == agents_dict.keys() + # Using assert to compare all DataFrames in the dictionaries + assert all( + compare_dataframes(df_active, df_all) + for df_active, df_all in zip( + active_agents_dict.values(), agents_dict.values() + ) + ) + + # Test with a mask + mask0 = pd.Series( + [True, False, True, True], index=agents._agentsets[0].index, dtype=bool + ) + mask1 = pl.Series("mask", [True, False, True, True], dtype=pl.Boolean) + mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1} + selected = agents.select(mask_dictionary, inplace=False) + assert ( + selected.active_agents[selected._agentsets[0]]["wealth"].to_list()[0] + == agents._agentsets[0]["wealth"].to_list()[0] + ) + assert ( + selected.active_agents[selected._agentsets[0]]["wealth"].to_list()[-1] + == agents._agentsets[0]["wealth"].to_list()[-1] + ) + assert ( + selected.active_agents[selected._agentsets[1]]["wealth"].to_list()[0] + == agents._agentsets[1]["wealth"].to_list()[0] + ) + assert ( + selected.active_agents[selected._agentsets[1]]["wealth"].to_list()[-1] + == agents._agentsets[1]["wealth"].to_list()[-1] + ) + + # Test with filter_func + def filter_func(agentset: AgentSetDF) -> pl.Series: + return agentset.agents["wealth"] > agentset.agents["wealth"][0] + + selected = agents.select(filter_func=filter_func, inplace=False) + assert ( + selected.active_agents[selected._agentsets[0]]["wealth"].to_list() + == agents._agentsets[0]["wealth"].to_list()[1:] + ) + assert ( + selected.active_agents[selected._agentsets[1]]["wealth"].to_list() + == agents._agentsets[1]["wealth"].to_list()[1:] + ) + + # Test with n + selected = agents.select(n=3, inplace=False) + assert sum(len(df) for df in selected.active_agents.values()) in [2, 3] + + # Test with n, filter_func and mask + selected = agents.select( + mask_dictionary, filter_func=filter_func, n=2, inplace=False + ) + assert any( + el in selected.active_agents[selected._agentsets[0]]["wealth"].to_list() + for el in agents.active_agents[agents._agentsets[0]]["wealth"].to_list()[ + 2:4 + ] + ) + assert any( + el in selected.active_agents[selected._agentsets[1]]["wealth"].to_list() + for el in agents.active_agents[agents._agentsets[1]]["wealth"].to_list()[ + 2:4 + ] + ) + + def test_set(self, fix_AgentsDF: AgentsDF): + agents = fix_AgentsDF + + # Test with a single attribute + result = agents.set("wealth", 0, inplace=False) + assert result._agentsets[0].agents["wealth"].to_list() == [0] * len( + agents._agentsets[0] + ) + assert result._agentsets[1].agents["wealth"].to_list() == [0] * len( + agents._agentsets[1] + ) + + # Test with a list of attributes + agents.set(["wealth", "age"], 1, inplace=True) + assert agents._agentsets[0].agents["wealth"].to_list() == [1] * len( + agents._agentsets[0] + ) + assert agents._agentsets[0].agents["age"].to_list() == [1] * len( + agents._agentsets[0] + ) + + # Test with a single attribute and a mask + mask0 = pd.Series( + [True] + [False] * (len(agents._agentsets[0]) - 1), + index=agents._agentsets[0].index, + dtype=bool, + ) + mask1 = pl.Series( + "mask", [True] + [False] * (len(agents._agentsets[1]) - 1), dtype=pl.Boolean + ) + mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1} + result = agents.set("wealth", 0, mask=mask_dictionary, inplace=False) + assert result._agentsets[0].agents["wealth"].to_list() == [0] + [1] * ( + len(agents._agentsets[0]) - 1 + ) + assert result._agentsets[1].agents["wealth"].to_list() == [0] + [1] * ( + len(agents._agentsets[1]) - 1 + ) + + # Test with a dictionary + agents.set( + {agents._agentsets[0]: {"wealth": 0}, agents._agentsets[1]: {"wealth": 1}}, + inplace=True, + ) + assert agents._agentsets[0].agents["wealth"].to_list() == [0] * len( + agents._agentsets[0] + ) + assert agents._agentsets[1].agents["wealth"].to_list() == [1] * len( + agents._agentsets[1] + ) + + def test_shuffle(self, fix_AgentsDF: AgentsDF): + agents = fix_AgentsDF + for _ in range(100): + original_order_0 = agents._agentsets[0].agents.index.to_list() + original_order_1 = agents._agentsets[1].agents["unique_id"].to_list() + agents.shuffle(inplace=True) + if ( + original_order_0 != agents._agentsets[0].agents.index.to_list() + and original_order_1 + != agents._agentsets[1].agents["unique_id"].to_list() + ): + return + assert False + + def test_sort(self, fix_AgentsDF: AgentsDF): + agents = fix_AgentsDF + agents.sort("wealth", ascending=False, inplace=True) + assert pl.Series(agents._agentsets[0].agents["wealth"]).is_sorted( + descending=True + ) + assert pl.Series(agents._agentsets[1].agents["wealth"]).is_sorted( + descending=True + ) + + def test__check_ids_presence( + self, + fix_AgentsDF: AgentsDF, + fix1_AgentSetPandas: ExampleAgentSetPandas, + ): + agents = fix_AgentsDF + agents_different_index = deepcopy(fix1_AgentSetPandas) + agents_different_index._agents.index = [-100, -200, -300, -400] + result = agents._check_ids_presence([fix1_AgentSetPandas]) + assert result.filter( + pl.col("unique_id").is_in(fix1_AgentSetPandas._agents.index) + )["present"].all() + assert not result.filter( + pl.col("unique_id").is_in(agents_different_index._agents.index) + )["present"].any() + + def test__check_agentsets_presence( + self, + fix_AgentsDF: AgentsDF, + fix1_AgentSetPandas: ExampleAgentSetPandas, + fix2_AgentSetPandas: ExampleAgentSetPandas, + ): + agents = fix_AgentsDF + result = agents._check_agentsets_presence( + [fix1_AgentSetPandas, fix2_AgentSetPandas] + ) + assert result[0] + assert not result[1] + + def test__get_bool_masks(self, fix_AgentsDF: AgentsDF): + agents = fix_AgentsDF + # Test with mask = None + result = agents._get_bool_masks(mask=None) + truth_value = True + for i, mask in enumerate(result.values()): + if isinstance(mask, pl.Expr): + mask = agents._agentsets[i]._agents.select(mask).to_series() + truth_value &= mask.all() + assert truth_value + + # Test with mask = "all" + result = agents._get_bool_masks(mask="all") + truth_value = True + for i, mask in enumerate(result.values()): + if isinstance(mask, pl.Expr): + mask = agents._agentsets[i]._agents.select(mask).to_series() + truth_value &= mask.all() + assert truth_value + + # Test with mask = "active" + mask0 = ( + agents._agentsets[0].agents["wealth"] + > agents._agentsets[0].agents["wealth"][0] + ) + mask1 = ( + agents._agentsets[1].agents["wealth"] + > agents._agentsets[1].agents["wealth"][0] + ) + mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1} + agents.select(mask=mask_dictionary) + result = agents._get_bool_masks(mask="active") + assert result[agents._agentsets[0]].to_list() == mask0.to_list() + assert result[agents._agentsets[1]].to_list() == mask1.to_list() + + # Test with mask = IdsLike + result = agents._get_bool_masks( + mask=[ + agents._agentsets[0].index[0], + agents._agentsets[1].agents["unique_id"][0], + ] + ) + assert result[agents._agentsets[0]].to_list() == [True] + [False] * ( + len(agents._agentsets[0]) - 1 + ) + assert result[agents._agentsets[1]].to_list() == [True] + [False] * ( + len(agents._agentsets[1]) - 1 + ) + + # Test with mask = dict[AgentSetDF, MaskLike] + result = agents._get_bool_masks(mask=mask_dictionary) + assert result[agents._agentsets[0]].to_list() == mask0.to_list() + assert result[agents._agentsets[1]].to_list() == mask1.to_list() + + def test__get_obj(self, fix_AgentsDF: AgentsDF): + agents = fix_AgentsDF + assert agents._get_obj(inplace=True) is agents + assert agents._get_obj(inplace=False) is not agents + + def test__return_agentsets_list( + self, + fix_AgentsDF: AgentsDF, + fix1_AgentSetPandas: ExampleAgentSetPandas, + fix2_AgentSetPandas: ExampleAgentSetPandas, + ): + agents = fix_AgentsDF + result = agents._return_agentsets_list(fix1_AgentSetPandas) + assert result == [fix1_AgentSetPandas] + result = agents._return_agentsets_list( + [fix1_AgentSetPandas, fix2_AgentSetPandas] + ) + assert result == [fix1_AgentSetPandas, fix2_AgentSetPandas] + + def test___add__( + self, + fix1_AgentSetPandas: ExampleAgentSetPandas, + fix2_AgentSetPolars: ExampleAgentSetPolars, + ): + model = ModelDF() + agents = AgentsDF(model) + agentset_pandas = fix1_AgentSetPandas + agentset_polars = fix2_AgentSetPolars + + # Test with a single AgentSetPandas + result = agents + agentset_pandas + assert result._agentsets[0] is agentset_pandas + assert result._ids.to_list() == agentset_pandas._agents.index.to_list() + + # Test with a single AgentSetPolars + result = agents + agentset_polars + assert result._agentsets[0] is agentset_polars + assert result._ids.to_list() == agentset_polars._agents["unique_id"].to_list() + + # Test with a list of AgentSetDFs + result = agents + [agentset_pandas, agentset_polars] + assert result._agentsets[0] is agentset_pandas + assert result._agentsets[1] is agentset_polars + assert ( + result._ids.to_list() + == agentset_pandas._agents.index.to_list() + + agentset_polars._agents["unique_id"].to_list() + ) + + # Test if adding the same AgentSetDF raises ValueError + with pytest.raises(ValueError): + result + agentset_pandas + + def test___contains__( + self, fix_AgentsDF: AgentsDF, fix2_AgentSetPandas: ExampleAgentSetPandas + ): + # Test with a single value + agents = fix_AgentsDF + agentset_pandas = agents._agentsets[0] + + # Test with an AgentSetDF + assert agentset_pandas in agents + # Test with an AgentSetDF not present + assert fix2_AgentSetPandas not in agents + + # Test with single id present + assert 0 in agents + + # Test with single id not present + assert 10 not in agents + + def test___copy__(self, fix_AgentsDF: AgentsDF): + agents = fix_AgentsDF + agents.test_list = [[1, 2, 3]] + + # Test with deep=False + agents2 = copy(agents) + agents2.test_list[0].append(4) + assert agents.test_list[0][-1] == agents2.test_list[0][-1] + assert agents.model == agents2.model + assert agents._agentsets[0] == agents2._agentsets[0] + assert (agents._ids == agents2._ids).all() + + def test___deepcopy__(self, fix_AgentsDF: AgentsDF): + agents = fix_AgentsDF + agents.test_list = [[1, 2, 3]] + + agents2 = deepcopy(agents) + agents2.test_list[0].append(4) + assert agents.test_list[-1] != agents2.test_list[-1] + assert agents.model == agents2.model + assert agents._agentsets[0] != agents2._agentsets[0] + assert (agents._ids == agents2._ids).all() + + def test___getattr__(self, fix_AgentsDF: AgentsDF): + agents = fix_AgentsDF + assert isinstance(agents.model, ModelDF) + result = agents.wealth + assert ( + result[agents._agentsets[0]].to_list() + == agents._agentsets[0].agents["wealth"].to_list() + ) + assert ( + result[agents._agentsets[1]].to_list() + == agents._agentsets[1].agents["wealth"].to_list() + ) + + def test___getitem__( + self, + fix_AgentsDF: AgentsDF, + fix1_AgentSetPandas: ExampleAgentSetPandas, + fix2_AgentSetPolars: ExampleAgentSetPolars, + ): + agents = fix_AgentsDF + + # Test with a single attribute + assert ( + agents["wealth"][fix1_AgentSetPandas].to_list() + == fix1_AgentSetPandas._agents["wealth"].to_list() + ) + assert ( + agents["wealth"][fix2_AgentSetPolars].to_list() + == fix2_AgentSetPolars._agents["wealth"].to_list() + ) + + # Test with a list of attributes + result = agents[["wealth", "age"]] + assert result[fix1_AgentSetPandas].columns.to_list() == ["wealth", "age"] + assert ( + result[fix1_AgentSetPandas]["wealth"].to_list() + == fix1_AgentSetPandas._agents["wealth"].to_list() + ) + assert ( + result[fix1_AgentSetPandas]["age"].to_list() + == fix1_AgentSetPandas._agents["age"].to_list() + ) + assert result[fix2_AgentSetPolars].columns == ["wealth", "age"] + assert ( + result[fix2_AgentSetPolars]["wealth"].to_list() + == fix2_AgentSetPolars._agents["wealth"].to_list() + ) + assert ( + result[fix2_AgentSetPolars]["age"].to_list() + == fix2_AgentSetPolars._agents["age"].to_list() + ) + + # Test with a single attribute and a mask + mask0 = ( + fix1_AgentSetPandas._agents["wealth"] + > fix1_AgentSetPandas._agents["wealth"][0] + ) + mask1 = ( + fix2_AgentSetPolars._agents["wealth"] + > fix2_AgentSetPolars._agents["wealth"][0] + ) + mask_dictionary: dict[AgentSetDF, MaskLike] = { + fix1_AgentSetPandas: mask0, + fix2_AgentSetPolars: mask1, + } + result = agents[mask_dictionary, "wealth"] + assert ( + result[fix1_AgentSetPandas].to_list() + == fix1_AgentSetPandas.agents["wealth"].to_list()[1:] + ) + assert ( + result[fix2_AgentSetPolars].to_list() + == fix2_AgentSetPolars.agents["wealth"].to_list()[1:] + ) + + def test___iadd__( + self, + fix1_AgentSetPandas: ExampleAgentSetPandas, + fix2_AgentSetPolars: ExampleAgentSetPolars, + ): + model = ModelDF() + agents = AgentsDF(model) + agentset_pandas = fix1_AgentSetPandas + agentset_polars = fix2_AgentSetPolars + + # Test with a single AgentSetPandas + agents_copy = deepcopy(agents) + agents_copy += agentset_pandas + assert agents_copy._agentsets[0] is agentset_pandas + assert agents_copy._ids.to_list() == agentset_pandas._agents.index.to_list() + + # Test with a single AgentSetPolars + agents_copy = deepcopy(agents) + agents_copy += agentset_polars + assert agents_copy._agentsets[0] is agentset_polars + assert ( + agents_copy._ids.to_list() == agentset_polars._agents["unique_id"].to_list() + ) + + # Test with a list of AgentSetDFs + agents_copy = deepcopy(agents) + agents_copy += [agentset_pandas, agentset_polars] + assert agents_copy._agentsets[0] is agentset_pandas + assert agents_copy._agentsets[1] is agentset_polars + assert ( + agents_copy._ids.to_list() + == agentset_pandas._agents.index.to_list() + + agentset_polars._agents["unique_id"].to_list() + ) + + # Test if adding the same AgentSetDF raises ValueError + with pytest.raises(ValueError): + agents_copy += agentset_pandas + + def test___iter__(self, fix_AgentsDF: AgentsDF): + agents = fix_AgentsDF + len_agentset0 = len(agents._agentsets[0]) + len_agentset1 = len(agents._agentsets[1]) + for i, agent in enumerate(agents): + assert isinstance(agent, dict) + if i < len_agentset0: + assert agent["unique_id"] == agents._agentsets[0].agents.index[i] + else: + assert ( + agent["unique_id"] + == agents._agentsets[1].agents["unique_id"][i - len_agentset0] + ) + assert i == len_agentset0 + len_agentset1 - 1 + + def test___isub__( + self, + fix_AgentsDF: AgentsDF, + fix1_AgentSetPandas: ExampleAgentSetPandas, + fix2_AgentSetPolars: ExampleAgentSetPolars, + ): + # Test with an AgentSetPolars and a DataFrame + agents = fix_AgentsDF + agents -= fix1_AgentSetPandas + assert agents._agentsets[0] == fix2_AgentSetPolars + assert len(agents._agentsets) == 1 + + def test___len__( + self, + fix_AgentsDF: AgentsDF, + fix1_AgentSetPandas: ExampleAgentSetPandas, + fix2_AgentSetPolars: ExampleAgentSetPolars, + ): + assert len(fix_AgentsDF) == len(fix1_AgentSetPandas) + len(fix2_AgentSetPolars) + + def test___repr__(self, fix_AgentsDF: AgentsDF): + repr(fix_AgentsDF) + + def test___reversed__(self, fix2_AgentSetPolars: AgentsDF): + agents = fix2_AgentSetPolars + reversed_wealth = [] + for agent in reversed(list(agents)): + reversed_wealth.append(agent["wealth"]) + assert reversed_wealth == list(reversed(agents["wealth"])) + + def test___setitem__(self, fix_AgentsDF: AgentsDF): + agents = fix_AgentsDF + + # Test with a single attribute + agents["wealth"] = 0 + assert agents._agentsets[0].agents["wealth"].to_list() == [0] * len( + agents._agentsets[0] + ) + assert agents._agentsets[1].agents["wealth"].to_list() == [0] * len( + agents._agentsets[1] + ) + + # Test with a list of attributes + agents[["wealth", "age"]] = 1 + assert agents._agentsets[0].agents["wealth"].to_list() == [1] * len( + agents._agentsets[0] + ) + assert agents._agentsets[0].agents["age"].to_list() == [1] * len( + agents._agentsets[0] + ) + + # Test with a single attribute and a mask + mask0 = pd.Series( + [True] + [False] * (len(agents._agentsets[0]) - 1), + index=agents._agentsets[0].index, + dtype=bool, + ) + mask1 = pl.Series( + "mask", [True] + [False] * (len(agents._agentsets[1]) - 1), dtype=pl.Boolean + ) + mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1} + agents[mask_dictionary, "wealth"] = 0 + assert agents._agentsets[0].agents["wealth"].to_list() == [0] + [1] * ( + len(agents._agentsets[0]) - 1 + ) + assert agents._agentsets[1].agents["wealth"].to_list() == [0] + [1] * ( + len(agents._agentsets[1]) - 1 + ) + + def test___str__(self, fix_AgentsDF: AgentsDF): + str(fix_AgentsDF) + + def test___sub__( + self, + fix_AgentsDF: AgentsDF, + fix1_AgentSetPandas: ExampleAgentSetPandas, + fix2_AgentSetPolars: ExampleAgentSetPolars, + ): + # Test with an AgentSetPolars and a DataFrame + result = fix_AgentsDF - fix1_AgentSetPandas + assert isinstance(result._agentsets[0], ExampleAgentSetPolars) + assert len(result._agentsets) == 1 + + def test_agents( + self, + fix_AgentsDF: AgentsDF, + fix1_AgentSetPandas: ExampleAgentSetPandas, + fix2_AgentSetPandas: ExampleAgentSetPandas, + fix2_AgentSetPolars: ExampleAgentSetPolars, + ): + assert isinstance(fix_AgentsDF.agents, dict) + assert len(fix_AgentsDF.agents) == 2 + assert fix_AgentsDF.agents[fix1_AgentSetPandas] is fix1_AgentSetPandas._agents + assert fix_AgentsDF.agents[fix2_AgentSetPolars] is fix2_AgentSetPolars._agents + + # Test agents.setter + fix_AgentsDF.agents = [fix1_AgentSetPandas, fix2_AgentSetPandas] + assert fix_AgentsDF._agentsets[0] == fix1_AgentSetPandas + assert fix_AgentsDF._agentsets[1] == fix2_AgentSetPandas + + def test_active_agents(self, fix_AgentsDF: AgentsDF): + agents = fix_AgentsDF + + # Test with select + mask0 = ( + agents._agentsets[0].agents["wealth"] + > agents._agentsets[0].agents["wealth"][0] + ) + mask1 = ( + agents._agentsets[1].agents["wealth"] + > agents._agentsets[1].agents["wealth"][0] + ) + mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1} + agents1 = agents.select(mask=mask_dictionary, inplace=False) + result = agents1.active_agents + assert isinstance(result, dict) + assert isinstance(result[agents1._agentsets[0]], pd.DataFrame) + assert isinstance(result[agents1._agentsets[1]], pl.DataFrame) + assert ( + (result[agents1._agentsets[0]] == agents1._agentsets[0]._agents[mask0]) + .all() + .all() + ) + assert all( + series.all() + for series in ( + result[agents1._agentsets[1]] + == agents1._agentsets[1]._agents.filter(mask1) + ) + ) + + # Test with active_agents.setter + agents1.active_agents = mask_dictionary + result = agents1.active_agents + assert isinstance(result, dict) + assert isinstance(result[agents1._agentsets[0]], pd.DataFrame) + assert isinstance(result[agents1._agentsets[1]], pl.DataFrame) + assert ( + (result[agents1._agentsets[0]] == agents1._agentsets[0]._agents[mask0]) + .all() + .all() + ) + assert all( + series.all() + for series in ( + result[agents1._agentsets[1]] + == agents1._agentsets[1]._agents.filter(mask1) + ) + ) + + def test_agentsets_by_type(self, fix_AgentsDF: AgentsDF): + agents = fix_AgentsDF + result = agents.agentsets_by_type + assert isinstance(result, dict) + assert isinstance(result[ExampleAgentSetPandas], AgentsDF) + assert isinstance(result[ExampleAgentSetPolars], AgentsDF) + assert result[ExampleAgentSetPandas]._agentsets == [agents._agentsets[0]] + assert result[ExampleAgentSetPolars]._agentsets == [agents._agentsets[1]] + + def test_inactive_agents(self, fix_AgentsDF: AgentsDF): + agents = fix_AgentsDF + + # Test with select + mask0 = ( + agents._agentsets[0].agents["wealth"] + > agents._agentsets[0].agents["wealth"][0] + ) + mask1 = ( + agents._agentsets[1].agents["wealth"] + > agents._agentsets[1].agents["wealth"][0] + ) + mask_dictionary = {agents._agentsets[0]: mask0, agents._agentsets[1]: mask1} + agents1 = agents.select(mask=mask_dictionary, inplace=False) + result = agents1.inactive_agents + assert isinstance(result, dict) + assert isinstance(result[agents1._agentsets[0]], pd.DataFrame) + assert isinstance(result[agents1._agentsets[1]], pl.DataFrame) + assert ( + result[agents1._agentsets[0]] + == agents1._agentsets[0].select(mask0, negate=True).active_agents + ).all().all() + assert all( + series.all() + for series in ( + result[agents1._agentsets[1]] + == agents1._agentsets[1].select(mask1, negate=True).active_agents + ) + ) diff --git a/tests/test_agentset_pandas.py b/tests/test_agentset_pandas.py index 34539f3..4093b5e 100644 --- a/tests/test_agentset_pandas.py +++ b/tests/test_agentset_pandas.py @@ -9,7 +9,7 @@ @tg.typechecked -class ExampleAgentSet(AgentSetPandas): +class ExampleAgentSetPandas(AgentSetPandas): def __init__(self, model: ModelDF, index: pd.Index): super().__init__(model) self.starting_wealth = pd.Series([1, 2, 3, 4], name="wealth", index=index) @@ -19,9 +19,9 @@ def add_wealth(self, amount: int) -> None: @pytest.fixture -def fix1_AgentSetPandas() -> ExampleAgentSet: +def fix1_AgentSetPandas() -> ExampleAgentSetPandas: model = ModelDF() - agents = ExampleAgentSet(model, pd.Index([0, 1, 2, 3], name="unique_id")) + agents = ExampleAgentSetPandas(model, pd.Index([0, 1, 2, 3], name="unique_id")) agents.add({"unique_id": [0, 1, 2, 3]}) agents["wealth"] = agents.starting_wealth agents["age"] = [10, 20, 30, 40] @@ -30,9 +30,9 @@ def fix1_AgentSetPandas() -> ExampleAgentSet: @pytest.fixture -def fix2_AgentSetPandas() -> ExampleAgentSet: +def fix2_AgentSetPandas() -> ExampleAgentSetPandas: model = ModelDF() - agents = ExampleAgentSet(model, pd.Index([4, 5, 6, 7], name="unique_id")) + agents = ExampleAgentSetPandas(model, pd.Index([4, 5, 6, 7], name="unique_id")) agents.add({"unique_id": [4, 5, 6, 7]}) agents["wealth"] = agents.starting_wealth + 10 agents["age"] = [100, 200, 300, 400] @@ -43,7 +43,7 @@ def fix2_AgentSetPandas() -> ExampleAgentSet: class Test_AgentSetPandas: def test__init__(self): model = ModelDF() - agents = ExampleAgentSet(model, pd.Index([0, 1, 2, 3])) + agents = ExampleAgentSetPandas(model, pd.Index([0, 1, 2, 3])) assert agents.model == model assert isinstance(agents.agents, pd.DataFrame) assert agents.agents.index.name == "unique_id" @@ -52,7 +52,9 @@ def test__init__(self): assert agents.starting_wealth.tolist() == [1, 2, 3, 4] def test_add( - self, fix1_AgentSetPandas: ExampleAgentSet, fix2_AgentSetPandas: ExampleAgentSet + self, + fix1_AgentSetPandas: ExampleAgentSetPandas, + fix2_AgentSetPandas: ExampleAgentSetPandas, ): agents = fix1_AgentSetPandas agents2 = fix2_AgentSetPandas @@ -76,7 +78,7 @@ def test_add( assert agents.agents.age.tolist() == [10, 20, 30, 40, 50, 60] assert agents.agents.index.name == "unique_id" - def test_contains(self, fix1_AgentSetPandas: ExampleAgentSet): + def test_contains(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas # Test with a single value @@ -86,7 +88,7 @@ def test_contains(self, fix1_AgentSetPandas: ExampleAgentSet): # Test with a list assert agents.contains([0, 1]).values.tolist() == [True, True] - def test_copy(self, fix1_AgentSetPandas: ExampleAgentSet): + def test_copy(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas agents.test_list = [[1, 2, 3]] @@ -101,7 +103,7 @@ def test_copy(self, fix1_AgentSetPandas: ExampleAgentSet): agents2.test_list[0].append(4) assert agents.test_list[-1] != agents2.test_list[-1] - def test_discard(self, fix1_AgentSetPandas: ExampleAgentSet): + def test_discard(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas # Test with a single value @@ -121,7 +123,7 @@ def test_discard(self, fix1_AgentSetPandas: ExampleAgentSet): result = agents.discard("active", inplace=False) assert result.agents.index.to_list() == [2, 3] - def test_do(self, fix1_AgentSetPandas: ExampleAgentSet): + def test_do(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas # Test with no_mask @@ -134,7 +136,7 @@ def test_do(self, fix1_AgentSetPandas: ExampleAgentSet): agents.do("add_wealth", 1, mask=agents["wealth"] > 3) assert agents.agents.wealth.tolist() == [3, 5, 6, 7] - def test_get(self, fix1_AgentSetPandas: ExampleAgentSet): + def test_get(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas # Test with a single attribute @@ -150,14 +152,14 @@ def test_get(self, fix1_AgentSetPandas: ExampleAgentSet): selected = agents.select(agents["wealth"] > 1, inplace=False) assert selected.get("wealth", mask="active").tolist() == [2, 3, 4] - def test_remove(self, fix1_AgentSetPandas: ExampleAgentSet): + def test_remove(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas agents.remove([0, 1]) assert agents.agents.index.tolist() == [2, 3] with pytest.raises(KeyError): agents.remove([1]) - def test_select(self, fix1_AgentSetPandas: ExampleAgentSet): + def test_select(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas # Test with default arguments. Should select all agents @@ -195,7 +197,7 @@ def filter_func(agentset: AgentSetPandas) -> pd.Series: selected = agents.select(mask, filter_func=filter_func, n=1, inplace=False) assert any(el in selected.active_agents.index.tolist() for el in [2, 3]) - def test_set(self, fix1_AgentSetPandas: ExampleAgentSet): + def test_set(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas # Test with a single attribute @@ -217,7 +219,7 @@ def test_set(self, fix1_AgentSetPandas: ExampleAgentSet): assert agents.agents.wealth.tolist() == [10, 10, 10, 10] assert agents.agents.age.tolist() == [20, 20, 20, 20] - def test_shuffle(self, fix1_AgentSetPandas: ExampleAgentSet): + def test_shuffle(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas for _ in range(10): original_order = agents.agents.index.tolist() @@ -226,13 +228,15 @@ def test_shuffle(self, fix1_AgentSetPandas: ExampleAgentSet): return assert False - def test_sort(self, fix1_AgentSetPandas: ExampleAgentSet): + def test_sort(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas agents.sort("wealth", ascending=False) assert agents.agents.wealth.tolist() == [4, 3, 2, 1] def test__add__( - self, fix1_AgentSetPandas: ExampleAgentSet, fix2_AgentSetPandas: ExampleAgentSet + self, + fix1_AgentSetPandas: ExampleAgentSetPandas, + fix2_AgentSetPandas: ExampleAgentSetPandas, ): agents = fix1_AgentSetPandas agents2 = fix2_AgentSetPandas @@ -253,13 +257,13 @@ def test__add__( assert agents3.agents.index.tolist() == [0, 1, 2, 3, 10] assert agents3.agents.wealth.tolist() == [1, 2, 3, 4, 5] - def test__contains__(self, fix1_AgentSetPandas: ExampleAgentSet): + def test__contains__(self, fix1_AgentSetPandas: ExampleAgentSetPandas): # Test with a single value agents = fix1_AgentSetPandas assert 0 in agents assert 4 not in agents - def test__copy__(self, fix1_AgentSetPandas: ExampleAgentSet): + def test__copy__(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas agents.test_list = [[1, 2, 3]] @@ -269,7 +273,7 @@ def test__copy__(self, fix1_AgentSetPandas: ExampleAgentSet): agents2.test_list[0].append(4) assert agents.test_list[0][-1] == agents2.test_list[0][-1] - def test__deepcopy__(self, fix1_AgentSetPandas: ExampleAgentSet): + def test__deepcopy__(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas agents.test_list = [[1, 2, 3]] @@ -277,12 +281,12 @@ def test__deepcopy__(self, fix1_AgentSetPandas: ExampleAgentSet): agents2.test_list[0].append(4) assert agents.test_list[-1] != agents2.test_list[-1] - def test__getattr__(self, fix1_AgentSetPandas: ExampleAgentSet): + def test__getattr__(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas assert isinstance(agents.model, ModelDF) assert agents.wealth.tolist() == [1, 2, 3, 4] - def test__getitem__(self, fix1_AgentSetPandas: ExampleAgentSet): + def test__getitem__(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas # Testing with a string @@ -300,7 +304,9 @@ def test__getitem__(self, fix1_AgentSetPandas: ExampleAgentSet): assert result["age"].values.tolist() == [10] def test__iadd__( - self, fix1_AgentSetPandas: ExampleAgentSet, fix2_AgentSetPandas: ExampleAgentSet + self, + fix1_AgentSetPandas: ExampleAgentSetPandas, + fix2_AgentSetPandas: ExampleAgentSetPandas, ): agents = deepcopy(fix1_AgentSetPandas) agents2 = fix2_AgentSetPandas @@ -324,34 +330,34 @@ def test__iadd__( assert agents.agents.index.tolist() == [0, 1, 2, 3, 10] assert agents.agents.wealth.tolist() == [1, 2, 3, 4, 5] - def test__iter__(self, fix1_AgentSetPandas: ExampleAgentSet): + def test__iter__(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas - for i, agent in agents: - assert isinstance(agent, pd.Series) - assert agent["wealth"] == i + 1 + for i, agent in enumerate(agents): + assert isinstance(agent, dict) + assert agent["unique_id"] == agents._agents.index[i] - def test__isub__(self, fix1_AgentSetPandas: ExampleAgentSet): + def test__isub__(self, fix1_AgentSetPandas: ExampleAgentSetPandas): # Test with an AgentSetPandas and a DataFrame agents = deepcopy(fix1_AgentSetPandas) agents -= agents.agents assert agents.agents.empty - def test__len__(self, fix1_AgentSetPandas: ExampleAgentSet): + def test__len__(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas assert len(agents) == 4 def test__repr__(self, fix1_AgentSetPandas): - agents: ExampleAgentSet = fix1_AgentSetPandas + agents: ExampleAgentSetPandas = fix1_AgentSetPandas repr(agents) - def test__reversed__(self, fix1_AgentSetPandas: ExampleAgentSet): + def test__reversed__(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas reversed_wealth = [] for i, agent in reversed(agents): reversed_wealth.append(agent["wealth"]) assert reversed_wealth == [4, 3, 2, 1] - def test__setitem__(self, fix1_AgentSetPandas: ExampleAgentSet): + def test__setitem__(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas agents = deepcopy(agents) # To test passing through a df later @@ -374,23 +380,25 @@ def test__setitem__(self, fix1_AgentSetPandas: ExampleAgentSet): assert agents.agents.loc[0, "wealth"] == 9 assert agents.agents.loc[0, "age"] == 99 - def test__str__(self, fix1_AgentSetPandas: ExampleAgentSet): - agents: ExampleAgentSet = fix1_AgentSetPandas + def test__str__(self, fix1_AgentSetPandas: ExampleAgentSetPandas): + agents: ExampleAgentSetPandas = fix1_AgentSetPandas str(agents) - def test__sub__(self, fix1_AgentSetPandas: ExampleAgentSet): - agents: ExampleAgentSet = fix1_AgentSetPandas - agents2: ExampleAgentSet = agents - agents.agents + def test__sub__(self, fix1_AgentSetPandas: ExampleAgentSetPandas): + agents: ExampleAgentSetPandas = fix1_AgentSetPandas + agents2: ExampleAgentSetPandas = agents - agents.agents assert agents2.agents.empty assert agents.agents.wealth.tolist() == [1, 2, 3, 4] - def test_get_obj(self, fix1_AgentSetPandas: ExampleAgentSet): + def test_get_obj(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas assert agents._get_obj(inplace=True) is agents assert agents._get_obj(inplace=False) is not agents def test_agents( - self, fix1_AgentSetPandas: ExampleAgentSet, fix2_AgentSetPandas: ExampleAgentSet + self, + fix1_AgentSetPandas: ExampleAgentSetPandas, + fix2_AgentSetPandas: ExampleAgentSetPandas, ): agents = fix1_AgentSetPandas agents2 = fix2_AgentSetPandas @@ -400,7 +408,7 @@ def test_agents( agents.agents = agents2.agents assert agents.agents.index.tolist() == [4, 5, 6, 7] - def test_active_agents(self, fix1_AgentSetPandas: ExampleAgentSet): + def test_active_agents(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas # Test with select @@ -411,7 +419,7 @@ def test_active_agents(self, fix1_AgentSetPandas: ExampleAgentSet): agents.active_agents = agents.agents.wealth > 2 assert agents.active_agents.index.to_list() == [2, 3] - def test_inactive_agents(self, fix1_AgentSetPandas: ExampleAgentSet): + def test_inactive_agents(self, fix1_AgentSetPandas: ExampleAgentSetPandas): agents = fix1_AgentSetPandas agents.select(agents["wealth"] > 2, inplace=True) diff --git a/tests/test_agentset_polars.py b/tests/test_agentset_polars.py index 3e54e6d..97a7983 100644 --- a/tests/test_agentset_polars.py +++ b/tests/test_agentset_polars.py @@ -9,7 +9,7 @@ @tg.typechecked -class ExampleAgentSet(AgentSetPolars): +class ExampleAgentSetPolars(AgentSetPolars): def __init__(self, model: ModelDF): super().__init__(model) self.starting_wealth = pl.Series("wealth", [1, 2, 3, 4]) @@ -19,9 +19,9 @@ def add_wealth(self, amount: int) -> None: @pytest.fixture -def fix1_AgentSetPolars() -> ExampleAgentSet: +def fix1_AgentSetPolars() -> ExampleAgentSetPolars: model = ModelDF() - agents = ExampleAgentSet(model) + agents = ExampleAgentSetPolars(model) agents.add({"unique_id": [0, 1, 2, 3]}) agents["wealth"] = agents.starting_wealth agents["age"] = [10, 20, 30, 40] @@ -29,9 +29,9 @@ def fix1_AgentSetPolars() -> ExampleAgentSet: @pytest.fixture -def fix2_AgentSetPolars() -> ExampleAgentSet: +def fix2_AgentSetPolars() -> ExampleAgentSetPolars: model = ModelDF() - agents = ExampleAgentSet(model) + agents = ExampleAgentSetPolars(model) agents.add({"unique_id": [4, 5, 6, 7]}) agents["wealth"] = agents.starting_wealth + 10 agents["age"] = [100, 200, 300, 400] @@ -41,7 +41,7 @@ def fix2_AgentSetPolars() -> ExampleAgentSet: class Test_AgentSetPolars: def test__init__(self): model = ModelDF() - agents = ExampleAgentSet(model) + agents = ExampleAgentSetPolars(model) agents.add({"unique_id": [0, 1, 2, 3]}) assert agents.model == model assert isinstance(agents.agents, pl.DataFrame) @@ -51,7 +51,9 @@ def test__init__(self): assert agents.starting_wealth.to_list() == [1, 2, 3, 4] def test_add( - self, fix1_AgentSetPolars: ExampleAgentSet, fix2_AgentSetPolars: ExampleAgentSet + self, + fix1_AgentSetPolars: ExampleAgentSetPolars, + fix2_AgentSetPolars: ExampleAgentSetPolars, ): agents = fix1_AgentSetPolars agents2 = fix2_AgentSetPolars @@ -72,7 +74,7 @@ def test_add( assert agents.agents["unique_id"].to_list() == [0, 1, 2, 3, 4, 5] assert agents.agents["age"].to_list() == [10, 20, 30, 40, 50, 60] - def test_contains(self, fix1_AgentSetPolars: ExampleAgentSet): + def test_contains(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars # Test with a single value @@ -82,7 +84,7 @@ def test_contains(self, fix1_AgentSetPolars: ExampleAgentSet): # Test with a list assert agents.contains([0, 1]).to_list() == [True, True] - def test_copy(self, fix1_AgentSetPolars: ExampleAgentSet): + def test_copy(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars agents.test_list = [[1, 2, 3]] @@ -96,7 +98,7 @@ def test_copy(self, fix1_AgentSetPolars: ExampleAgentSet): agents2.test_list[0].append(4) assert agents.test_list[-1] != agents2.test_list[-1] - def test_discard(self, fix1_AgentSetPolars: ExampleAgentSet): + def test_discard(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars # Test with a single value @@ -116,7 +118,7 @@ def test_discard(self, fix1_AgentSetPolars: ExampleAgentSet): result = agents.discard("active", inplace=False) assert result.agents["unique_id"].to_list() == [2, 3] - def test_do(self, fix1_AgentSetPolars: ExampleAgentSet): + def test_do(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars # Test with no return_results, no mask @@ -131,7 +133,7 @@ def test_do(self, fix1_AgentSetPolars: ExampleAgentSet): agents.do("add_wealth", 1, mask=agents["wealth"] > 3) assert agents.agents["wealth"].to_list() == [3, 5, 6, 7] - def test_get(self, fix1_AgentSetPolars: ExampleAgentSet): + def test_get(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars # Test with a single attribute @@ -147,14 +149,14 @@ def test_get(self, fix1_AgentSetPolars: ExampleAgentSet): selected = agents.select(agents.agents["wealth"] > 1, inplace=False) assert selected.get("wealth", mask="active").to_list() == [2, 3, 4] - def test_remove(self, fix1_AgentSetPolars: ExampleAgentSet): + def test_remove(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars agents.remove([0, 1]) assert agents.agents["unique_id"].to_list() == [2, 3] with pytest.raises(KeyError): agents.remove([1]) - def test_select(self, fix1_AgentSetPolars: ExampleAgentSet): + def test_select(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars # Test with default arguments. Should select all agents @@ -195,7 +197,7 @@ def filter_func(agentset: AgentSetPolars) -> pl.Series: selected = agents.select(mask, filter_func=filter_func, n=1, inplace=False) assert any(el in selected.active_agents["unique_id"].to_list() for el in [2, 3]) - def test_set(self, fix1_AgentSetPolars: ExampleAgentSet): + def test_set(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars # Test with a single attribute @@ -217,7 +219,7 @@ def test_set(self, fix1_AgentSetPolars: ExampleAgentSet): assert agents.agents["wealth"].to_list() == [10, 10, 10, 10] assert agents.agents["age"].to_list() == [20, 20, 20, 20] - def test_shuffle(self, fix1_AgentSetPolars: ExampleAgentSet): + def test_shuffle(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars for _ in range(10): original_order = agents.agents["unique_id"].to_list() @@ -226,13 +228,15 @@ def test_shuffle(self, fix1_AgentSetPolars: ExampleAgentSet): return assert False - def test_sort(self, fix1_AgentSetPolars: ExampleAgentSet): + def test_sort(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars agents.sort("wealth", ascending=False) assert agents.agents["wealth"].to_list() == [4, 3, 2, 1] def test__add__( - self, fix1_AgentSetPolars: ExampleAgentSet, fix2_AgentSetPolars: ExampleAgentSet + self, + fix1_AgentSetPolars: ExampleAgentSetPolars, + fix2_AgentSetPolars: ExampleAgentSetPolars, ): agents = fix1_AgentSetPolars agents2 = fix2_AgentSetPolars @@ -253,13 +257,13 @@ def test__add__( assert agents3.agents["unique_id"].to_list() == [0, 1, 2, 3, 10] assert agents3.agents["wealth"].to_list() == [1, 2, 3, 4, 5] - def test__contains__(self, fix1_AgentSetPolars: ExampleAgentSet): + def test__contains__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): # Test with a single value agents = fix1_AgentSetPolars assert 0 in agents assert 4 not in agents - def test__copy__(self, fix1_AgentSetPolars: ExampleAgentSet): + def test__copy__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars agents.test_list = [[1, 2, 3]] @@ -268,7 +272,7 @@ def test__copy__(self, fix1_AgentSetPolars: ExampleAgentSet): agents2.test_list[0].append(4) assert agents.test_list[0][-1] == agents2.test_list[0][-1] - def test__deepcopy__(self, fix1_AgentSetPolars: ExampleAgentSet): + def test__deepcopy__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars agents.test_list = [[1, 2, 3]] @@ -276,12 +280,12 @@ def test__deepcopy__(self, fix1_AgentSetPolars: ExampleAgentSet): agents2.test_list[0].append(4) assert agents.test_list[-1] != agents2.test_list[-1] - def test__getattr__(self, fix1_AgentSetPolars: ExampleAgentSet): + def test__getattr__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars assert isinstance(agents.model, ModelDF) assert agents.wealth.to_list() == [1, 2, 3, 4] - def test__getitem__(self, fix1_AgentSetPolars: ExampleAgentSet): + def test__getitem__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars # Testing with a string @@ -299,7 +303,9 @@ def test__getitem__(self, fix1_AgentSetPolars: ExampleAgentSet): assert result["age"].to_list() == [10] def test__iadd__( - self, fix1_AgentSetPolars: ExampleAgentSet, fix2_AgentSetPolars: ExampleAgentSet + self, + fix1_AgentSetPolars: ExampleAgentSetPolars, + fix2_AgentSetPolars: ExampleAgentSetPolars, ): agents = deepcopy(fix1_AgentSetPolars) agents2 = fix2_AgentSetPolars @@ -323,34 +329,34 @@ def test__iadd__( assert agents.agents["unique_id"].to_list() == [0, 1, 2, 3, 10] assert agents.agents["wealth"].to_list() == [1, 2, 3, 4, 5] - def test__iter__(self, fix1_AgentSetPolars: ExampleAgentSet): + def test__iter__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars for i, agent in enumerate(agents): assert isinstance(agent, dict) assert agent["wealth"] == i + 1 - def test__isub__(self, fix1_AgentSetPolars: ExampleAgentSet): + def test__isub__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): # Test with an AgentSetPolars and a DataFrame agents = deepcopy(fix1_AgentSetPolars) agents -= agents.agents assert agents.agents.is_empty() - def test__len__(self, fix1_AgentSetPolars: ExampleAgentSet): + def test__len__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars assert len(agents) == 4 def test__repr__(self, fix1_AgentSetPolars): - agents: ExampleAgentSet = fix1_AgentSetPolars + agents: ExampleAgentSetPolars = fix1_AgentSetPolars repr(agents) - def test__reversed__(self, fix1_AgentSetPolars: ExampleAgentSet): + def test__reversed__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars reversed_wealth = [] for i, agent in reversed(list(enumerate(agents))): reversed_wealth.append(agent["wealth"]) assert reversed_wealth == [4, 3, 2, 1] - def test__setitem__(self, fix1_AgentSetPolars: ExampleAgentSet): + def test__setitem__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars agents = deepcopy(agents) # To test passing through a df later @@ -373,23 +379,25 @@ def test__setitem__(self, fix1_AgentSetPolars: ExampleAgentSet): assert agents.agents.item(0, "wealth") == 9 assert agents.agents.item(0, "age") == 99 - def test__str__(self, fix1_AgentSetPolars: ExampleAgentSet): - agents: ExampleAgentSet = fix1_AgentSetPolars + def test__str__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): + agents: ExampleAgentSetPolars = fix1_AgentSetPolars str(agents) - def test__sub__(self, fix1_AgentSetPolars: ExampleAgentSet): - agents: ExampleAgentSet = fix1_AgentSetPolars - agents2: ExampleAgentSet = agents - agents.agents + def test__sub__(self, fix1_AgentSetPolars: ExampleAgentSetPolars): + agents: ExampleAgentSetPolars = fix1_AgentSetPolars + agents2: ExampleAgentSetPolars = agents - agents.agents assert agents2.agents.is_empty() assert agents.agents["wealth"].to_list() == [1, 2, 3, 4] - def test_get_obj(self, fix1_AgentSetPolars: ExampleAgentSet): + def test_get_obj(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars assert agents._get_obj(inplace=True) is agents assert agents._get_obj(inplace=False) is not agents def test_agents( - self, fix1_AgentSetPolars: ExampleAgentSet, fix2_AgentSetPolars: ExampleAgentSet + self, + fix1_AgentSetPolars: ExampleAgentSetPolars, + fix2_AgentSetPolars: ExampleAgentSetPolars, ): agents = fix1_AgentSetPolars agents2 = fix2_AgentSetPolars @@ -399,7 +407,7 @@ def test_agents( agents.agents = agents2.agents assert agents.agents["unique_id"].to_list() == [4, 5, 6, 7] - def test_active_agents(self, fix1_AgentSetPolars: ExampleAgentSet): + def test_active_agents(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars # Test with select @@ -410,7 +418,7 @@ def test_active_agents(self, fix1_AgentSetPolars: ExampleAgentSet): agents.active_agents = agents.agents["wealth"] > 2 assert agents.active_agents["unique_id"].to_list() == [2, 3] - def test_inactive_agents(self, fix1_AgentSetPolars: ExampleAgentSet): + def test_inactive_agents(self, fix1_AgentSetPolars: ExampleAgentSetPolars): agents = fix1_AgentSetPolars agents.select(agents.agents["wealth"] > 2, inplace=True)