From 4e61b0d305f981be66f45a2b8783dd190ef5d2d2 Mon Sep 17 00:00:00 2001 From: Adam Amer <136176500+adamamer20@users.noreply.github.com> Date: Tue, 23 Apr 2024 17:22:29 +0200 Subject: [PATCH 1/2] dunder methods for agentsetpandas, new API, construct an abstract class --- mesa_frames/__init__.py | 4 +- mesa_frames/agent.py | 1201 +++++++++++++++++++++++++++++++-------- 2 files changed, 953 insertions(+), 252 deletions(-) diff --git a/mesa_frames/__init__.py b/mesa_frames/__init__.py index a80f01a..c48f87d 100644 --- a/mesa_frames/__init__.py +++ b/mesa_frames/__init__.py @@ -1,2 +1,2 @@ -from .agent import AgentsDF, AgentSetDF -from .model import ModelDF +from mesa_frames.agent import AgentSetPandas, AgentSetPolars, AgentsPandas, AgentsPolars +from mesa_frames.model import ModelDF diff --git a/mesa_frames/agent.py b/mesa_frames/agent.py index cfa6e61..1084f79 100644 --- a/mesa_frames/agent.py +++ b/mesa_frames/agent.py @@ -1,5 +1,8 @@ +from __future__ import annotations # PEP 563: postponed evaluation of type annotations + from abc import ABC, abstractmethod from contextlib import suppress +from copy import copy, deepcopy from typing import ( TYPE_CHECKING, Any, @@ -14,28 +17,49 @@ import pandas as pd import polars as pl +from numpy import int64, ndarray +from pandas.core.arrays.base import ExtensionArray +from polars.datatypes import N_INFER_DEFAULT from mesa_frames.model import ModelDF -if TYPE_CHECKING: +# For AgentSetPandas.select +ArrayLike = ExtensionArray | ndarray +AnyArrayLike = ArrayLike | pd.Index | pd.Series +ListLike = AnyArrayLike | list | range - # For AgentSetDF - from numpy.random import Generator +# For AgentSetPandas.drop +IndexLabel = Hashable | Sequence[Hashable] - DataFrameLike = pd.DataFrame | pl.DataFrame - MaskLike = pd.Series[bool] | pl.Expr | pl.Series +# For AgentContainer.__getitem__ and AgentContainer.__setitem__ +DataFrame = pd.DataFrame | pl.DataFrame - # For AgentSetPandas - from numpy import ndarray - from pandas.core.arrays.base import ExtensionArray +Series = pd.Series | pl.Series - from .model import ModelDF +BoolSeries = pd.Series | pl.Expr | pl.Series + +PandasMaskLike = ( + Literal["active"] | Literal["all"] | pd.Series | pd.DataFrame | ListLike | Hashable +) + +PolarsMaskLike = ( + Literal["active"] + | Literal["all"] + | pl.Expr + | pl.Series + | pl.DataFrame + | ListLike + | Hashable +) + +MaskLike = PandasMaskLike | PolarsMaskLike + +if TYPE_CHECKING: + + # For AgentSetDF + from numpy.random import Generator - ArrayLike = ExtensionArray | ndarray - AnyArrayLike = ArrayLike | pd.Index | pd.Series ValueKeyFunc = Callable[[pd.Series], pd.Series | AnyArrayLike] | None - from pandas._typing import Axes, Dtype, ListLikeU - from polars.datatypes import N_INFER_DEFAULT # For AgentSetPolars from polars.type_aliases import ( @@ -52,106 +76,492 @@ class AgentContainer(ABC): model: ModelDF - _mask: MaskLike - """An abstract class for containing agents. Defines the common interface for AgentSetDF and AgentsDF.""" - - def __init__(self, model: ModelDF) -> None: - self.model = model + _mask: BoolSeries + _skip_copy: list[str] = ["model", "_mask"] + """An abstract class for containing agents. Defines the common interface for AgentSetDF and AgentsDF. - def __get_item__(self, attr_name: str) -> Any: - return self.get_attribute(attr_name) - - def __set_item__(self, attr_name: str, value: Any) -> None: - self.set_attribute(attr_name, value) + Attributes + ---------- + model : ModelDF + The model to which the AgentContainer belongs. + _mask : Series + A boolean mask indicating which agents are active. + _skip_copy : list[str] + A list of attributes to skip during the copy process. + """ + + def __new__(cls, model: ModelDF) -> Self: + """Create a new AgentContainer object. + + Parameters + ---------- + model : ModelDF + The model to which the AgentContainer belongs. + + Returns + ------- + Self + A new AgentContainer object. + """ + obj = super().__new__(cls) + obj.model = model + return obj + + def __add__(self, other: Self | DataFrame | ListLike | dict[str, Any]) -> Self: + """Add agents to a new AgentContainer through the + operator. + + Other can be: + - A Self: adds the agents from the other AgentContainer. + - A DataFrame: adds the agents from the DataFrame. + - A ListLike: should be one single agent to add. + - A dictionary: keys should be attributes and values should be the values to add. + + Parameters + ---------- + other : Self | DataFrame | ListLike | dict[str, Any] + The agents to add. + + Returns + ------- + Self + A new AgentContainer with the added agents. + """ + new_obj = deepcopy(self) + return new_obj.add(other) + + def __iadd__(self, other: Self | DataFrame | ListLike | dict[str, Any]) -> Self: + """Add agents to the AgentContainer through the += operator. + + Other can be: + - A Self: adds the agents from the other AgentContainer. + - A DataFrame: adds the agents from the DataFrame. + - A ListLike: should be one single agent to add. + - A dictionary: keys should be attributes and values should be the values to add. + + Parameters + ---------- + other : Self | DataFrame | ListLike | dict[str, Any] + The agents to add. + + Returns + ------- + Self + The updated AgentContainer. + """ + return self.add(other) + + @abstractmethod + def __contains__(self, id: Hashable) -> bool: + """Check if an agent is in the AgentContainer. + + Parameters + ---------- + id : Hashable + The ID(s) to check for. + + Returns + ------- + bool + True if the agent is in the AgentContainer, False otherwise. + """ + + def __copy__(self) -> Self: + """Create a shallow copy of the AgentContainer. + + Returns + ------- + Self + A shallow copy of the AgentContainer. + """ + return self.copy(deep=False) + + def __deepcopy__(self, memo: dict) -> Self: + """Create a deep copy of the AgentContainer. + + Parameters + ---------- + memo : dict + A dictionary to store the copied objects. + + Returns + ------- + Self + A deep copy of the AgentContainer. + """ + return self.copy(deep=True, memo=memo) + + def __getattr__(self, name: str) -> Series: + """Fallback for retrieving attributes of the AgentContainer. Retrieves an attribute column of the agents in the AgentContainer. + + Parameters + ---------- + name : str + The name of the attribute to retrieve. + + Returns + ------- + Series + The attribute values. + """ + return self.get_attribute(name) + + @overload + def __getitem__(self, key: str | tuple[MaskLike, str]) -> Series: ... + + @overload + def __getitem__(self, key: list[str]) -> DataFrame: ... + + def __getitem__( + self, key: str | list[str] | MaskLike | tuple[MaskLike, str | list[str]] + ) -> Series | DataFrame: # tuple is not generic so it is not type hintable + """Implement the [] operator for the AgentContainer. + + The key can be: + - A string (eg. AgentContainer["str"]): returns the specified column of the agents in the AgentContainer. + - A list of strings(eg. AgentContainer[["str1", "str2"]]): returns the specified columns of the agents in the AgentContainer. + - A tuple (eg. AgentContainer[mask, "str"]): returns the specified column of the agents in the AgentContainer that satisfy the mask. + - A mask (eg. AgentContainer[mask]): returns the agents in the AgentContainer that satisfy the mask. + + Parameters + ---------- + key : str | list[str] | MaskLike | tuple[MaskLike, str | list[str]] + The key to retrieve. + + Returns + ------- + Series | DataFrame + The attribute values. + """ + if isinstance(key, (str, list)): + return self.get_attribute(attr_names=key) + + elif isinstance(key, tuple): + return self.get_attribute(mask=key[0], attr_names=key[1]) + + else: # MaskLike + return self.get_attribute(mask=key) + + @abstractmethod + def __iter__(self) -> Iterable: + """Iterate over the agents in the AgentContainer. + + Returns + ------- + Iterable + An iterator over the agents. + """ + + def __isub__(self, other: MaskLike) -> Self: + """Remove agents from the AgentContainer through the -= operator. + + Parameters + ---------- + other : Self | DataFrame | ListLike + The agents to remove. + + Returns + ------- + Self + The updated AgentContainer. + """ + return self.discard(other) + + @abstractmethod + def __len__(self) -> int | dict[str, int]: + """Get the number of agents in the AgentContainer. + + Returns + ------- + int | dict[str, int] + The number of agents in the AgentContainer. + """ + + @abstractmethod + def __repr__(self) -> str: + """Get a string representation of the DataFrame in the AgentContainer. + + Returns + ------- + str + A string representation of the DataFrame in the AgentContainer. + """ + return repr(self.agents) + + def __setitem__( + self, + key: str | list[str] | MaskLike | tuple[MaskLike, str | list[str]], + value: Any, + ) -> None: + """Implement the [] operator for setting values in the AgentContainer. + + The key can be: + - A string (eg. AgentContainer["str"]): sets the specified column of the agents in the AgentContainer. + - A list of strings(eg. AgentContainer[["str1", "str2"]]): sets the specified columns of the agents in the AgentContainer. + - A tuple (eg. AgentContainer[mask, "str"]): sets the specified column of the agents in the AgentContainer that satisfy the mask. + - A mask (eg. AgentContainer[mask]): sets the attributes of the agents in the AgentContainer that satisfy the mask. + + Parameters + ---------- + key : str | list[str] | MaskLike | tuple[MaskLike, str | list[str] + The key to set. + """ + if isinstance(key, (str, list)): + self.set_attribute(attr_names=key, value=value) + + elif isinstance(key, tuple): + self.set_attribute(mask=key[0], attr_names=key[1], value=value) + + else: # key=MaskLike + self.set_attribute(mask=key, value=value) + + @abstractmethod + def __str__(self) -> str: + """Get a string representation of the DataFrame in the AgentContainer. + + Returns + ------- + str + A string representation of the DataFrame in the AgentContainer. + """ + + def __sub__(self, other: MaskLike) -> Self: + """Remove agents from a new AgentContainer through the - operator. + + Parameters + ---------- + other : DataFrame | ListLike + The agents to remove. + + Returns + ------- + Self + A new AgentContainer with the removed agents. + """ + new_obj = deepcopy(self) + return new_obj.discard(other) + + @abstractmethod + def __reversed__(self) -> Iterable: + """Iterate over the agents in the AgentContainer in reverse order. + + Returns + ------- + Iterable + An iterator over the agents in reverse order. + """ @property @abstractmethod - def active_agents(self) -> DataFrameLike: - """The active agents in the AgentContainer (those that are used for the do, set_attribute, get_attribute operations). + def agents(self) -> DataFrame: + """The agents in the AgentContainer. Returns ------- - DataFrameLike + DataFrame """ - pass - + + @property + @abstractmethod + def active_agents(self) -> DataFrame: + """The active agents in the AgentContainer. + + Returns + ------- + DataFrame + """ + @active_agents.setter - def active_agents(self, agents: DataFrameLike | MaskLike) -> None: - self.select(mask=agents) + def active_agents(self, mask: MaskLike) -> None: + """Set the active agents in the AgentContainer. + + Parameters + ---------- + mask : MaskLike + The mask to apply. + """ + self.select(mask=mask) @property @abstractmethod - def inactive_agents(self) -> DataFrameLike: - """The inactive agents in the AgentContainer (those that are not used for the do, set_attribute, get_attribute operations). + def inactive_agents(self) -> DataFrame: + """The inactive agents in the AgentContainer. Returns ------- - DataFrameLike + DataFrame """ - pass - @inactive_agents.setter - def inactive_agents(self, agents: DataFrameLike | MaskLike) -> None: - self.select(mask=agents) - @property def random(self) -> Generator: """ Provide access to the model's random number generator. - Returns: - ---------- + Returns + ------- np.Generator """ return self.model.random + def _get_obj(self, inplace: bool) -> Self: + """Get the object to perform operations on. + + Parameters + ---------- + inplace : bool + If inplace, return self. Otherwise, return a copy. + + Returns + ---------- + Self + The object to perform operations on. + """ + if inplace: + return self + else: + return deepcopy(self) + @abstractmethod - def select( + def contains(self, ids: MaskLike) -> BoolSeries: + """Check if agents with the specified IDs are in the AgentContainer. + + Parameters + ---------- + id : MaskLike + The ID(s) to check for. + + Returns + ------- + BoolSeries + """ + + def copy( self, - mask: MaskLike | DataFrameLike | None = None, - filter_func: Callable[[Self], MaskLike] | None = None, - n: int = 0, + deep: bool = False, + skip: list[str] | str | None = None, + memo: dict | None = None, ) -> Self: - """ - Selects a subset of agents based on the given criteria. + """Create a copy of the AgentContainer. - Parameters: + Parameters ---------- - mask (MaskLike | DataFrameLike | None): A boolean mask or DataFrame used to filter the agents. - filter_func (Callable[[AgentContainer], MaskLike] | None): A function that takes an AgentContainer and returns a boolean mask. - n (int): The maximum number of agents to select. + deep : bool, optional + Flag indicating whether to perform a deep copy of the AgentContainer. + If True, all attributes of the AgentContainer will be recursively copied (except self.agents, check Pandas/Polars documentation). + If False, only the top-level attributes will be copied. + Defaults to False. + + skip : list[str] | str | None, optional + A list of attribute names or a single attribute name to skip during the copy process. + If an attribute name is specified, it will be skipped for all levels of the copy. + If a list of attribute names is specified, they will be skipped for all levels of the copy. + If None, no attributes will be skipped. + Defaults to None. + + memo : dict | None, optional + A dictionary used to track already copied objects during deep copy. + Defaults to None. - Returns: + Returns + ------- + Self + A new instance of the AgentContainer class that is a copy of the original instance. + """ + skip_list = self._skip_copy.copy() + cls = self.__class__ + obj = cls.__new__(cls, self.model) + if isinstance(skip, str): + skip_list.append(skip) + elif isinstance(skip, list): + skip_list += skip + if deep: + if not memo: + memo = {} + memo[id(self)] = obj + attributes = self.__dict__.copy() + setattr(obj, "model", attributes.pop("model")) + [ + setattr(obj, k, deepcopy(v, memo)) + for k, v in attributes.items() + if k not in skip_list + ] + else: + [ + setattr(obj, k, copy(v)) + for k, v in self.__dict__.items() + if k not in skip_list + ] + return obj + + @abstractmethod + def select( + self, + mask: MaskLike | None = None, + filter_func: Callable[[Self], MaskLike] | None = None, + n: int | None = None, + inplace: bool = True, + ) -> Self: + """Select agents in the AgentContainer based on the given criteria. + + Parameters ---------- - AgentContainer: Returns an AgentContainer with selected agents as active. + mask : MaskLike | None, optional + The mask of agents to be selected, by default None + filter_func : Callable[[Self], MaskLike] | None, optional + A function which takes as input the AgentContainer and returns a MaskLike, by default None + n : int, optional + The maximum number of agents to be selected, by default None + inplace : bool, optional + If the operation should be performed on the same object, by default True + Returns + ------- + Self + A new or updated AgentContainer. """ - pass @abstractmethod - def shuffle(self) -> Self: + def shuffle(self, inplace: bool = True) -> Self: """ Shuffles the order of agents in the AgentContainer. - Returns: + Parameters ---------- - AgentContainer: The shuffled agent set. + inplace : bool + Whether to shuffle the agents in place. + + Returns + ---------- + Self + A new or updated AgentContainer. """ - pass @abstractmethod - def sort(self, *args, **kwargs) -> Self: + def sort(self, *args, inplace: bool = True, **kwargs) -> Self: """ Sorts the agents in the agent set based on the given criteria. + + Parameters + ---------- + *args + Positional arguments to pass to the sort method. + inplace : bool + Whether to sort the agents in place. + **kwargs + Keyword arguments to pass to the sort + + Returns + ---------- + Self + A new or updated AgentContainer. """ - pass @overload def do( self, method_name: str, - return_results: Literal[False] = False, *args, + return_results: Literal[False] = False, + inplace: bool = True, **kwargs, ) -> Self: ... @@ -159,266 +569,558 @@ def do( def do( self, method_name: str, - return_results: Literal[True], *args, + return_results: Literal[True], + inplace: bool = True, **kwargs, ) -> Any: ... def do( self, method_name: str, - return_results: bool = False, *args, + return_results: bool = False, + inplace: bool = True, **kwargs, ) -> Self | Any: - """ - Invokes a method on the AgentContainer. + """Invoke a method on the AgentContainer. - Parameters: + Parameters ---------- - method_name (str): The name of the method to call. - return_results (bool): Whether to return the results of the method call. - *args: Positional arguments to pass to the method. - **kwargs: Keyword arguments to pass to the method. + method_name : str + The name of the method to invoke. + return_results : bool, optional + Whether to return the result of the method, by default False + inplace : bool, optional + Whether the operation should be done inplace, by default True - Returns: - ---------- - AgentContainer | Any: The updated agent set or the results of the method call. + Returns + ------- + Self | Any + The updated AgentContainer or the result of the method. """ - method = getattr(self, method_name) + obj = self._get_obj(inplace) + method = getattr(obj, method_name) if return_results: return method(*args, **kwargs) else: method(*args, **kwargs) - return self + return obj @abstractmethod - def get_attribute(self, attr_name: str) -> MaskLike: + @overload + def get_attribute( + self, + attr_names: list[str] | None = None, + mask: MaskLike | None = None, + ) -> DataFrame: ... + + @abstractmethod + @overload + def get_attribute( + self, + attr_names: str, + mask: MaskLike | None = None, + ) -> Series: ... + + @abstractmethod + def get_attribute( + self, + attr_names: str | list[str] | None = None, + mask: MaskLike | None = None, + ) -> Series | DataFrame: """ Retrieves the value of a specified attribute for each agent in the AgentContainer. - Parameters: + Parameters ---------- - attr_name (str): The name of the attribute to retrieve. + attr_names : str | list[str] | None + The name of the attribute to retrieve. If None, all attributes are retrieved. Defaults to None. + mask : MaskLike | None + The mask of agents to retrieve the attribute for. If None, attributes of all agents are returned. Defaults to None. - Returns: + Returns ---------- - MaskLike: The attribute values. + Series | DataFrame + The attribute values. """ - pass @abstractmethod - def set_attribute(self, attr_name: str, value: Any) -> Self: + @overload + def set_attribute( + self, + attr_names: None = None, + value: Any = Any, + mask: MaskLike = MaskLike, + inplace: bool = True, + ) -> Self: ... + + @abstractmethod + @overload + def set_attribute( + self, + attr_names: dict[str, Any], + value: None, + mask: MaskLike | None = None, + inplace: bool = True, + ) -> Self: ... + + @abstractmethod + @overload + def set_attribute( + self, + attr_names: str | list[str], + value: Any, + mask: MaskLike | None = None, + inplace: bool = True, + ) -> Self: ... + + @abstractmethod + def set_attribute( + self, + attr_names: str | dict[str, Any] | list[str] | None = None, + value: Any | None = None, + mask: MaskLike | None = None, + inplace: bool = True, + ) -> Self: """ - Sets the value of a specified attribute for each agent in the AgentContainer. + Sets the value of a specified attribute or attributes for each agent in the AgentContainer. + + The key can be: + - A string: sets the specified column of the agents in the AgentContainer. + - A list of strings: sets the specified columns of the agents in the AgentContainer. + - A dictionary: keys should be attributes and values should be the values to set. Value should be None. - Parameters: + Parameters ---------- - attr_name (str): The name of the attribute to set. - value (Any): The value to set the attribute to. + attr_names : str | dict[str, Any] + The name of the attribute to set. + value : Any | None + The value to set the attribute to. If None, attr_names must be a dictionary. + mask : MaskLike | None + The mask of agents to set the attribute for. + inplace : bool + Whether to set the attribute in place. - Returns: + Returns ---------- - AgentContainer: The updated agent set. + AgentContainer + The updated agent set. """ - pass @abstractmethod - def add(self, n: int, *args, **kwargs) -> Self: - """Adds new agents to the AgentContainer.""" - pass + def add( + self, other: Self | DataFrame | ListLike | dict[str, Any], inplace: bool = True + ) -> Self: + """Adds agents to the AgentContainer. - @abstractmethod - def discard(self, id: int) -> Self: + Other can be: + - A Self: adds the agents from the other AgentContainer. + - A DataFrame: adds the agents from the DataFrame. + - A ListLike: should be one single agent to add. + - A dictionary: keys should be attributes and values should be the values to add. + + Parameters + ---------- + other : Self | DataFrame | ListLike | dict[str, Any] + The agents to add. + inplace : bool, optional + Whether the operation is done into place, by default True + + Returns + ------- + Self + The updated AgentContainer. """ - Removes an agent from the AgentContainer. - Parameters: + def discard(self, id: MaskLike, inplace: bool = True) -> Self: + """ + Removes an agent from the AgentContainer. Does not raise an error if the agent is not found. + + Parameters ---------- - id (int): The ID of the agent to remove. + id : ListLike | Any + The ID of the agent to remove. + inplace : bool + Whether to remove the agent in place. - Returns: + Returns ---------- - AgentContainer: The updated AgentContainer. + AgentContainer + The updated AgentContainer. """ - pass + with suppress(KeyError): + return self.remove(id, inplace=inplace) @abstractmethod - def remove(self, id: int) -> Self: + def remove(self, id: MaskLike, inplace: bool = True) -> Self: """ Removes an agent from the AgentContainer. - Parameters: + Parameters ---------- - id (int): The ID of the agent to remove. + id : ListLike | Any + The ID of the agent to remove. + inplace : bool + Whether to remove the agent in place. - Returns: + Returns ---------- - AgentContainer: The updated AgentContainer. + AgentContainer + The updated AgentContainer. """ - pass ### The AgentSetDF class is a container for agents of the same type. It has an implementation with Pandas and Polars ### class AgentSetDF(AgentContainer): - agents: DataFrameLike - """An abstract class for a set of agents of the same type.""" - + _agents: DataFrame + _skip_copy = ["model", "_mask", "_agents"] + """A container for agents of the same type. + + Attributes + ---------- + model : ModelDF + The model to which the AgentSetDF belongs. + _mask : Series + A boolean mask indicating which agents are active. + _agents : DataFrame + The agents in the AgentSetDF. + _skip_copy : list[str] + A list of attributes to skip during the copy process. + """ -class AgentSetPandas(AgentSetDF): - agents: pd.DataFrame - _mask: pd.Series[bool] - """A pandas-based implementation of the AgentSet.""" + @property + def agents(self) -> DataFrame: + """The agents in the AgentSetDF.""" + return self._agents - def __init__(self, model: ModelDF): - """Create a new AgentSetDF. + @agents.setter + def agents_setter(self, agents: DataFrame) -> None: + """Set the agents in the AgentSetDF. Parameters ---------- - model : ModelDF - The model to which the AgentSetDF belongs. + agents : DataFrame + The agents to set. + """ + self._agents = agents - Attributes + def __len__(self) -> int: + return len(self._agents) + + def __repr__(self) -> str: + return repr(self._agents) + + def __str__(self) -> str: + return str(self._agents) + + def contains(self, ids: MaskLike) -> BoolSeries | bool: + + if isinstance( + ids, + (ListLike, Series, pl.Expr, DataFrame), + ) or ids == "all" or ids == "active": + return self._get_bool_mask(ids) + else: + return ids in self + + @abstractmethod + def _get_bool_mask(self, mask: MaskLike) -> BoolSeries: + """Get a boolean mask for the agents in the AgentSet. + + The mask can be: + - "all": all agents are selected. + - "active": only active agents are selected. + - A ListLike of IDs: only agents with the specified IDs are selected. + - A DataFrame: only agents with indices in the DataFrame are selected. + - A BoolSeries: only agents with True values are selected. + - Any other value: only the agent with the specified ID value is selected. + + Parameters ---------- - agents : pd.DataFrame - The agents in the AgentSetDF. - model : ModelDF - The model to which the AgentSetDF belongs. + mask : MaskLike + The mask to apply. + + Returns + ------- + BoolSeries + The boolean mask for the agents. """ - super().__init__(model) + + +class AgentSetPandas(AgentSetDF): + _agents: pd.DataFrame + _mask: pd.Series[bool] + """A pandas-based implementation of the AgentSet. + + Attributes + ---------- + model : ModelDF + The model to which the AgentSet belongs. + _mask : pd.Series[bool] + A boolean mask indicating which agents are active. + _agents : pd.DataFrame + The agents in the AgentSet. + _skip_copy : list[str] + A list of attributes to skip during the copy process. + """ + + def __new__(cls, model: ModelDF) -> Self: + obj = super().__new__(cls, model) + obj._agents = pd.DataFrame(columns=["unique_id"]).set_index("unique_id") + obj._mask = pd.Series(True, index=obj._agents.index) + return obj + + def __contains__(self, id: Hashable) -> bool: + return id in self._agents.index + + def __deepcopy__(self, memo: dict) -> Self: + obj = super().__deepcopy__(memo) + obj._agents = self._agents.copy(deep=True) + return obj + + def __iter__(self): + return self._agents.iterrows() + + def __reversed__(self) -> Iterable: + return self._agents[::-1].iterrows() + + @property + def agents(self) -> pd.DataFrame: + return self._agents @property def active_agents(self) -> pd.DataFrame: - return self.agents.loc[self._mask] + return self._agents.loc[self._mask] + + @active_agents.setter # When a property is overriden, so it is the getter + def active_agents(self, mask: PandasMaskLike) -> None: + return AgentContainer.active_agents.fset(self, mask) # type: ignore @property def inactive_agents(self) -> pd.DataFrame: - return self.agents.loc[~self._mask] + return self._agents.loc[~self._mask] + + def _get_bool_mask( + self, + mask: PandasMaskLike | None = None, + ) -> pd.Series: + if isinstance(mask, pd.Series) and mask.dtype == bool: + return mask + elif isinstance(mask, self.__class__): + return pd.Series( + self._agents.index.isin(mask.agents.index), index=self._agents.index + ) + elif isinstance(mask, pd.DataFrame): + return pd.Series( + self._agents.index.isin(mask.index), index=self._agents.index + ) + elif isinstance(mask, list): + return pd.Series(self._agents.index.isin(mask), index=self._agents.index) + elif mask is None or mask == "all": + return pd.Series(True, index=self._agents.index) + elif mask == "active": + return self._mask + else: + return pd.Series(self._agents.index.isin([mask]), index=self._agents.index) + + def copy( + self, + deep: bool = False, + skip: list[str] | str | None = None, + memo: dict | None = None, + ) -> Self: + obj = super().copy(deep, skip, memo) + obj._agents = self._agents.copy(deep=deep) + obj._mask = self._mask.copy(deep=deep) + return obj def select( self, - mask: pd.Series[bool] | pd.DataFrame | None = None, - filter_func: Callable[[Self], pd.Series[bool]] | None = None, - n: int = 0, + mask: PandasMaskLike | None = None, + filter_func: Callable[[Self], PandasMaskLike] | None = None, + n: int | None = None, + inplace: bool = True, ) -> Self: - if mask is None: - mask = pd.Series(True, index=self.agents.index) - elif isinstance(mask, pd.DataFrame): - mask = pd.Series( - self.agents.index.isin(mask.index), index=self.agents.index + obj = self._get_obj(inplace) + bool_mask = obj._get_bool_mask(mask) + if n != None: + bool_mask = pd.Series( + obj._agents.index.isin(obj._agents[bool_mask].sample(n).index), + index=obj._agents.index, ) if filter_func: - mask = mask & filter_func(self) - if n != 0: - mask = pd.Series(self.agents[mask].sample(n).index.isin(self.agents.index)) - self._mask = mask - return self + bool_mask = bool_mask & obj._get_bool_mask(filter_func(obj)) + obj._mask = bool_mask + return obj - def shuffle(self) -> Self: - self.agents = self.agents.sample(frac=1) - return self + def shuffle(self, inplace: bool = True) -> Self: + obj = self._get_obj(inplace) + obj._agents = obj._agents.sample(frac=1) + return obj def sort( self, by: str | Sequence[str], - key: ValueKeyFunc | None, + key: ValueKeyFunc | None = None, ascending: bool | Sequence[bool] = True, + inplace: bool = True, ) -> Self: """ Sort the agents in the agent set based on the given criteria. - Parameters: + Parameters ---------- - by : str | Sequence[str] - The attribute(s) to sort by. - key : ValueKeyFunc | None - A function to use for sorting. - ascending : bool | Sequence[bool] - Whether to sort in ascending order. + by : str | Sequence[str] + The attribute(s) to sort by. + key : ValueKeyFunc | None + A function to use for sorting. + ascending : bool | Sequence[bool] + Whether to sort in ascending order. - Returns: + Returns ---------- - AgentSetDF: The sorted agent set. + AgentSetDF: The sorted agent set. """ - self.agents.sort_values(by=by, key=key, ascending=ascending, inplace=True) - return self - - def get_attribute(self, attr_name: str) -> pd.Series[Any]: - return self.agents.loc[ - self.agents.index.isin(self.active_agents.index), attr_name - ] + obj = self._get_obj(inplace) + obj._agents.sort_values(by=by, key=key, ascending=ascending, inplace=True) + return obj - def set_attribute(self, attr_name: str, value: Any) -> Self: - self.agents.loc[self.agents.index.isin(self.active_agents.index), attr_name] = ( - value - ) - return self + @overload + def set_attribute( + self, + attr_names: None = None, + value: Any = Any, + mask: PandasMaskLike = PandasMaskLike, + inplace: bool = True, + ) -> Self: ... - def add( + @overload + def set_attribute( self, - n: int, - data: ( - ListLikeU - | pd.DataFrame - | dict[Any, Any] - | Iterable[ListLikeU | tuple[Hashable, ListLikeU] | dict[Any, Any]] - | None - ) = None, - index: Axes | None = None, - copy: bool = False, - columns: Axes | None = None, - dtype: Dtype | None = None, - ) -> Self: - """ - Adds new agents to the agent set. + attr_names: dict[str, Any], + value: None, + mask: PandasMaskLike | None = None, + inplace: bool = True, + ) -> Self: ... - Parameters: - ---------- - n (int): The number of agents to add. - data (ListLikeU | pd.DataFrame | dict[Any, Any] | Iterable[ListLikeU | tuple[Hashable, ListLikeU] | dict[Any, Any]] | None): The data for the new agents. - index (Axes | None): The index for the new agents. - copy (bool): Whether to copy the data. - columns (Axes | None): The columns for the new agents. - dtype (Dtype | None): The data type for the new agents. + @overload + def set_attribute( + self, + attr_names: str | list[str], + value: Any, + mask: PandasMaskLike | None = None, + inplace: bool = True, + ) -> Self: ... - Returns: - ---------- - AgentSetDF: The updated agent set. - """ - if not index: - index = pd.Index((self.random.random(n) * 10**8).astype(int)) + def set_attribute( + self, + attr_names: str | list[str] | dict[str, Any] | None = None, + value: Any | None = None, + mask: PandasMaskLike | None = None, + inplace: bool = True, + ) -> Self: + obj = self._get_obj(inplace) + mask = obj._get_bool_mask(mask) + if attr_names is None: + attr_names = obj._agents.columns.values.tolist() + if isinstance(attr_names, (str, list)) and value is not None: + obj._agents.loc[mask, attr_names] = value + elif isinstance(attr_names, dict): + for key, value in attr_names.items(): + obj._agents.loc[mask, key] = value + else: + raise ValueError( + "attr_names must be a string or a dictionary with columns as keys and values." + ) + return obj - new_df = pd.DataFrame( - data=data, - index=index, - columns=columns, - dtype=dtype, - copy=copy, - ) + @overload + def get_attribute( + self, + attr_names: list[str] | None = None, + mask: PandasMaskLike | None = None, + ) -> pd.DataFrame: ... - self.agents = pd.concat([self.agents, new_df]) + @overload + def get_attribute( + self, + attr_names: str, + mask: PandasMaskLike | None = None, + ) -> pd.Series: ... - if self._mask.empty: - self._mask = pd.Series(True, index=new_df.index) + def get_attribute( + self, + attr_names: str | list[str] | None = None, + mask: PandasMaskLike | None = None, + inplace: bool = True, + ) -> pd.Series | pd.DataFrame: + obj = self._get_obj(inplace) + mask = obj._get_bool_mask(mask) + if attr_names is None: + return obj._agents.loc[mask] else: - self._mask = pd.concat([self._mask, pd.Series(True, index=new_df.index)]) - - return self + return obj._agents.loc[mask, attr_names] - def discard(self, id: int) -> Self: - with suppress(KeyError): - self.agents.drop(id, inplace=True) - return self + def add( + self, + other: Self | pd.DataFrame | ListLike | dict[str, Any], + inplace: bool = True, + ) -> Self: + obj = self._get_obj(inplace) + if isinstance(other, obj.__class__): + new_agents = other.agents + elif isinstance(other, pd.DataFrame): + new_agents = other + if "unique_id" != other.index.name: + try: + new_agents.set_index("unique_id", inplace=True, drop=True) + except KeyError: + new_agents["unique_id"] = obj.random.random(len(other)) * 10**8 + elif isinstance(other, dict): + if "unique_id" not in other: + index = obj.random.random(len(other)) * 10**8 + if not isinstance(other["unique_id"], ListLike): + index = [other["unique_id"]] + else: + index = other["unique_id"] + new_agents = ( + pd.DataFrame(other, index=pd.Index(index)) + .reset_index(drop=True) + .set_index("unique_id") + ) + else: # ListLike + if len(other) == len(obj._agents.columns): + # data missing unique_id + new_agents = pd.DataFrame([other], columns=obj._agents.columns) + new_agents["unique_id"] = obj.random.random(1) * 10**8 + elif len(other) == len(obj._agents.columns) + 1: + new_agents = pd.DataFrame( + [other], columns=["unique_id"] + obj._agents.columns.values.tolist() + ) + else: + raise ValueError( + "Length of data must match the number of columns in the AgentSet if being added as a ListLike." + ) + new_agents.set_index("unique_id", inplace=True, drop=True) + obj._agents = pd.concat([obj._agents, new_agents]) + return obj - def remove(self, id: int) -> Self: - self.agents.drop(id, inplace=True) - return self + def remove(self, id: PandasMaskLike, inplace: bool = True) -> Self: + initial_len = len(self._agents) + obj = self._get_obj(inplace) + mask = obj._get_bool_mask(id) + remove_ids = obj._agents[mask].index + obj._agents.drop(remove_ids, inplace=True) + if len(obj._agents) == initial_len: + raise KeyError(f"IDs {id} not found in agent set.") + return obj class AgentSetPolars(AgentSetDF): - agents: pl.DataFrame + _agents: pl.DataFrame _mask: pl.Expr | pl.Series """A polars-based implementation of the AgentSet.""" @@ -438,9 +1140,15 @@ def __init__(self, model: ModelDF): The model to which the AgentSetDF belongs. """ super().__init__(model) - self.agents = pl.DataFrame(schema={"unique_id": pl.String}) + self._agents = pl.DataFrame(schema={"unique_id": pl.String}) self._mask = pl.repeat(True, len(self.agents)) + @property + def agents(self) -> pl.DataFrame: + if self._agents is None: + self._agents = pl.DataFrame(schema={"unique_id": pl.String}) + return self._agents + @property def active_agents(self) -> pl.DataFrame: return self.agents.filter(self._mask) @@ -483,14 +1191,14 @@ def sort( ) -> Self: """Sort the agents in the agent set based on the given criteria. - Parameters: + Parameters ---------- by (IntoExpr | Iterable[IntoExpr]): The attribute(s) to sort by. more_by (IntoExpr): Additional attributes to sort by. descending (bool | Sequence[bool]): Whether to sort in descending order. nulls_last (bool): Whether to place null values last. - Returns: + Returns ---------- AgentSetDF: The sorted agent set. """ @@ -499,14 +1207,14 @@ def sort( ) return self - def get_attribute(self, attr_name: str) -> pl.Series: - return self.agents.filter(self._mask)[attr_name] + def get_attribute(self, attr_names: str) -> pl.Series: + return self.agents.filter(self._mask)[attr_names] - def set_attribute(self, attr_name: str, value: Any) -> Self: + def set_attribute(self, attr_names: str, value: Any) -> Self: if type(value) == pl.Series: - self.agents.filter(self._mask).with_columns(**{attr_name: value}) + self.agents.filter(self._mask).with_columns(**{attr_names: value}) else: - self.agents.filter(self._mask).with_columns(**{attr_name: pl.lit(value)}) + self.agents.filter(self._mask).with_columns(**{attr_names: pl.lit(value)}) return self def add( @@ -521,7 +1229,7 @@ def add( ) -> Self: """Adds new agents to the agent set. - Parameters: + Parameters ---------- n : int The number of agents to add. @@ -546,7 +1254,7 @@ def add( nan_to_null : bool, default False If the data comes from one or more numpy arrays, can optionally convert input data np.nan values to null instead. This is a no-op for all other input data. - Returns: + Returns ---------- AgentSetPolars: The updated agent set. """ @@ -594,6 +1302,15 @@ def __init__(self, model: ModelDF) -> None: super().__init__(model) self.agentsets = [] + def __len__(self) -> int: + return sum(len(agentset.agents) for agentset in self.agentsets) + + def __repr__(self): + return self.agentsets.__repr__() + + def __str__(self) -> str: + return self.agentsets.__str__() + def sort( self, by: str | Sequence[str], @@ -650,17 +1367,17 @@ def add(self, agentsets: AgentSetDF | list[AgentSetDF]) -> Self: return self @abstractmethod - def to_frame(self) -> DataFrameLike: + def to_frame(self) -> DataFrame: """Convert the AgentsDF to a single DataFrame. Returns ------- - DataFrameLike + DataFrame A DataFrame containing all agents from all AgentSetDFs. """ pass - def get_agents_of_type(self, agent_type: type) -> AgentSetDF: + def get_agents_of_type(self, agent_type: type[AgentSetDF]) -> AgentSetDF: """Retrieve the AgentSetDF of a specified type. Parameters @@ -678,9 +1395,9 @@ def get_agents_of_type(self, agent_type: type) -> AgentSetDF: return agentset raise ValueError(f"No AgentSetDF of type {agent_type} found.") - def set_attribute(self, attr_name: str, value: Any) -> Self: + def set_attribute(self, attr_names: str, value: Any) -> Self: self.agentsets = [ - agentset.set_attribute(attr_name, value) for agentset in self.agentsets + agentset.set_attribute(attr_names, value) for agentset in self.agentsets ] return self @@ -757,21 +1474,13 @@ def select( self.agentsets = new_agentsets return self - def get_attribute(self, attr_name: str) -> pd.Series[Any]: + def get_attribute(self, attr_names: str) -> pd.Series[Any]: return pd.concat( - [agentset.get_attribute(attr_name) for agentset in self.agentsets] + [agentset.get_attribute(attr_names) for agentset in self.agentsets] ) def add(self, agentsets: AgentSetPandas | list[AgentSetPandas]) -> Self: - if isinstance(agentsets, list) and not all( - isinstance(agentset, AgentSetPandas) for agentset in agentsets - ): - raise ValueError("All agentsets must be of type AgentSetPandas.") - elif not isinstance(agentsets, AgentSetPandas): - raise ValueError( - "agentsets must be of type AgentSetPandas or list[AgentSetPandas]." - ) - return super().add(agentsets) + return super().add(agentsets) # type: ignore class AgentsPolars(AgentsDF): @@ -825,18 +1534,10 @@ def select( self.agentsets = new_agentsets return self - def get_attribute(self, attr_name: str) -> pl.Series: + def get_attribute(self, attr_names: str) -> pl.Series: return pl.concat( - [agentset.get_attribute(attr_name) for agentset in self.agentsets] + [agentset.get_attribute(attr_names) for agentset in self.agentsets] ) def add(self, agentsets: AgentSetPolars | list[AgentSetPolars]) -> Self: - if isinstance(agentsets, list) and not all( - isinstance(agentset, AgentSetPolars) for agentset in agentsets - ): - raise ValueError("All agentsets must be of type AgentSetPolars.") - elif not isinstance(agentsets, AgentSetPolars): - raise ValueError( - "agentsets must be of type AgentSetPolars or list[AgentSetPolars]." - ) - return super().add(agentsets) + return super().add(agentsets) # type: ignore #child classes are not checked? From a79d1b522767b49b0fe87643ff0fbc6149da9729 Mon Sep 17 00:00:00 2001 From: Adam Amer <136176500+adamamer20@users.noreply.github.com> Date: Tue, 23 Apr 2024 17:23:10 +0200 Subject: [PATCH 2/2] implementation of test for methods --- tests/test_agent.py | 410 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 410 insertions(+) create mode 100644 tests/test_agent.py diff --git a/tests/test_agent.py b/tests/test_agent.py new file mode 100644 index 0000000..b3375ae --- /dev/null +++ b/tests/test_agent.py @@ -0,0 +1,410 @@ +from copy import copy, deepcopy + +import pandas as pd +import pytest +import typeguard as tg +from mesa import Model +from numpy.random import Generator + +from mesa_frames import ( + AgentSetPandas, + AgentSetPolars, + AgentsPandas, + AgentsPolars, + ModelDF, + agent, +) + + +@tg.typechecked +class ExampleAgentSet(AgentSetPandas): + def __init__(self, model: ModelDF): + self.starting_wealth = pd.Series([1, 2, 3, 4], name="wealth") + + def add_wealth(self, amount: int) -> None: + self.agents["wealth"] += amount + + +@pytest.fixture +def fix1_AgentSetPandas(): + model = ModelDF() + agents = ExampleAgentSet(model) + agents.add({"unique_id": [0, 1, 2, 3]}) + agents.agents["wealth"] = agents.starting_wealth + agents.agents["age"] = [10, 20, 30, 40] + return agents + + +@pytest.fixture +def fix2_AgentSetPandas(): + model = ModelDF() + agents = ExampleAgentSet(model) + agents.add({"unique_id": [10, 11, 12, 13]}) + agents.agents["wealth"] = agents.starting_wealth + 10 + agents.agents["age"] = [100, 200, 300, 400] + + return agents + + +class Test_AgentSetPandas: + def test__init__(self): + model = ModelDF() + forbidden_model = Model() + try: + agents = ExampleAgentSet(forbidden_model) + except Exception as e: + assert type(e) == tg.TypeCheckError + agents = ExampleAgentSet(model) + assert agents.model == model + assert isinstance(agents.agents, pd.DataFrame) + assert isinstance(agents._mask, pd.Series) + assert isinstance(agents.random, Generator) + assert agents.starting_wealth.tolist() == [1, 2, 3, 4] + + def test__add__(self, fix1_AgentSetPandas, fix2_AgentSetPandas): + agents = fix1_AgentSetPandas + agents2 = fix2_AgentSetPandas + + # Test with two AgentSetPandas + agents3 = agents + agents2 + assert agents3.agents.index.tolist() == [0, 1, 2, 3, 10, 11, 12, 13] + + # Test with an AgentSetPandas and a DataFrame + agents3 = agents + agents2.agents + assert agents3.agents.index.tolist() == [0, 1, 2, 3, 10, 11, 12, 13] + + # Test with an AgentSetPandas and a list + agents3 = agents + [10, 5] # 10 should be unique id and 5 should be wealth + assert agents3.agents.index.tolist()[:-1] == [0, 1, 2, 3] + assert len(agents3.agents) == 5 + assert agents3.agents.wealth.tolist() == [1, 2, 3, 4, 10] + + # Test with an AgentSetPandas and a dict + agents3 = agents + {"unique_id": 10, "wealth": 5} + assert agents3.agents.index.tolist() == [0, 1, 2, 3, 10] + assert agents3.agents.wealth.tolist() == [1, 2, 3, 4, 5] + + def test__contains__(self, fix1_AgentSetPandas): + # Test with a single value + agents = fix1_AgentSetPandas + assert 0 in agents + assert 4 not in agents + + def test__copy__(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + agents2 = copy(agents) + agents[0, "wealth"] = 5 + assert agents is not agents2 + assert agents[0, "wealth"].values == agents2[0, "wealth"].values + assert agents2.model is agents.model + + def test__deepcopy__(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + agents2 = deepcopy(agents) + agents[0, "wealth"] = 5 + assert agents is not agents2 + assert agents[0, "wealth"].values != agents2[0, "wealth"].values + assert agents2.model is agents.model + + def test__getattr__(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + assert isinstance(agents.model, ModelDF) + assert agents.wealth.tolist() == [1, 2, 3, 4] + + def test__getitem__(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + + # Testing with a string + assert agents["wealth"].tolist() == [1, 2, 3, 4] + + # Test with a tuple[MaskLike, str] + assert agents[0, "wealth"].values == 1 + + # Test with a list[str] + assert agents[["wealth", "age"]].columns.tolist() == ["wealth", "age"] + + # Testing with a tuple[MaskLike, list[str]] + result = agents[0, ["wealth", "age"]] + assert result["wealth"].values.tolist() == [1] + assert result["age"].values.tolist() == [10] + + def test__iadd__(self, fix1_AgentSetPandas, fix2_AgentSetPandas): + agents = deepcopy(fix1_AgentSetPandas) + agents2 = fix2_AgentSetPandas + + # Test with two AgentSetPandas + agents += agents2 + assert agents.agents.index.tolist() == [0, 1, 2, 3, 10, 11, 12, 13] + + # Test with an AgentSetPandas and a DataFrame + agents = deepcopy(fix1_AgentSetPandas) + agents += agents2.agents + assert agents.agents.index.tolist() == [0, 1, 2, 3, 10, 11, 12, 13] + + # Test with an AgentSetPandas and a list + agents = deepcopy(fix1_AgentSetPandas) + agents += [10, 5] + assert agents.agents.index.tolist()[:-1] == [0, 1, 2, 3] + assert len(agents.agents) == 5 + assert agents.agents.wealth.tolist() == [1, 2, 3, 4, 10] + + # Test with an AgentSetPandas and a dict + agents = deepcopy(fix1_AgentSetPandas) + agents += {"unique_id": 10, "wealth": 5} + assert agents.agents.index.tolist() == [0, 1, 2, 3, 10] + + def test__iter__(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + for i, agent in agents: + assert isinstance(agent, pd.Series) + assert agent["wealth"] == i + 1 + + def test__isub__(self, fix1_AgentSetPandas): + agents = deepcopy(fix1_AgentSetPandas) + + # Test with two AgentSetPandas + agents -= agents + assert agents.agents.empty + + # Test with an AgentSetPandas and a DataFrame + agents = deepcopy(fix1_AgentSetPandas) + agents -= agents.agents + assert agents.agents.empty + + def test__len__(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + assert len(agents) == 4 + + def test__repr__(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + repr(agents) + + def test__reversed__(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + reversed_wealth = [] + for i, agent in reversed(agents): + reversed_wealth.append(agent["wealth"]) + assert reversed_wealth == [4, 3, 2, 1] + + def test__setitem__(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + + agents = deepcopy(agents) # To test passing through a df later + + # Test with key=str, value=Any + agents["wealth"] = 0 + assert agents.agents.wealth.tolist() == [0, 0, 0, 0] + + # Test with key=list[str], value=Any + agents[["wealth", "age"]] = 1 + assert agents.agents.wealth.tolist() == [1, 1, 1, 1] + assert agents.agents.age.tolist() == [1, 1, 1, 1] + + # Test with key=tuple, value=Any + agents[0, "wealth"] = 5 + assert agents.agents.wealth.tolist() == [5, 1, 1, 1] + + # Test with key=MaskLike, value=Any + agents[0] = [9, 99] + assert agents.agents.loc[0, "wealth"] == 9 + assert agents.agents.loc[0, "age"] == 99 + + def test__str__(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + str(agents) + + def test__sub__(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + agents2 = agents - agents + assert agents2.agents.empty + assert agents.agents.wealth.tolist() == [1, 2, 3, 4] + + def test_get_object(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + assert agents._get_obj(inplace=True) is agents + assert agents._get_obj(inplace=False) is not agents + + def test_active_agents(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + agents.active_agents = [0, 1] + assert agents.active_agents.index.to_list() == [0, 1] + + def test_add(self, fix1_AgentSetPandas, fix2_AgentSetPandas): + agents = fix1_AgentSetPandas + agents2 = fix2_AgentSetPandas + + # Test with self + result = agents.add(agents2, inplace=False) + assert result.agents.index.to_list() == [0, 1, 2, 3, 10, 11, 12, 13] + + # Test with a DataFrame + result = agents.add(agents2.agents, inplace=False) + assert result.agents.index.to_list() == [0, 1, 2, 3, 10, 11, 12, 13] + + # Test with a list + result = agents.add([10, 5, 10], inplace=False) + assert result.agents.index.to_list() == [0, 1, 2, 3, 10] + assert result.agents.wealth.to_list() == [1, 2, 3, 4, 5] + assert result.agents.age.to_list() == [10, 20, 30, 40, 10] + + # Test with a dict + agents.add({"unique_id": [4, 5], "wealth": [5, 6], "age": [50, 60]}) + assert agents.agents.wealth.tolist() == [1, 2, 3, 4, 5, 6] + assert agents.agents.index.tolist() == [0, 1, 2, 3, 4, 5] + assert agents.agents.age.tolist() == [10, 20, 30, 40, 50, 60] + + def test_contains(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + + # Test with a single value + assert agents.contains(0) + assert not agents.contains(4) + + # Test with a list + assert agents.contains([0, 1]).values.tolist() == [True, True] + + # Test with a pd.DataFrame + assert agents.contains(pd.DataFrame({"unique_id": [0, 4]})).values.tolist() == [True, False] + + def test_copy(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + + # Test with deep=False + agents2 = agents.copy(deep=False) + agents2[0, "wealth"] = 5 + assert agents[0, "wealth"].values == agents2[0, "wealth"].values + assert agents2.model is agents.model + + # Test with deep=True + agents2 = fix1_AgentSetPandas.copy(deep=True) + agents2[0, "wealth"] = 3 + assert agents[0, "wealth"].values != agents2[0, "wealth"].values + assert agents2.model is agents.model + + # Test by skipping starting_wealth + agents2 = agents.copy(skip=["starting_wealth"]) + with pytest.raises(KeyError) as e: + agents2.starting_wealth + assert "starting_wealth" in str(e) + + def test_discard(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + + # Test with a single value + result = agents.discard(0, inplace=False) + assert result.agents.index.to_list() == [1, 2, 3] + + # Test with a list + result = agents.discard([0, 1], inplace=False) + assert agents.agents.index.tolist() == [2, 3] + + # Test with a pd.DataFrame + result = agents.discard(pd.DataFrame({"unique_id": [0, 1]}), inplace=False) + assert result.agents.index.to_list() == [2, 3] + + # Test with active_agents + agents.active_agents = [0, 1] + result = agents.discard("active", inplace=False) + assert result.agents.index.to_list() == [2, 3] + + def test_do(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + agents.do("add_wealth", 1) + assert agents.agents.wealth.tolist() == [2, 3, 4, 5] + assert agents.do("add_wealth", 1, return_results=True) == None + + def test_get_attribute(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + + # Test with a single attribute + assert agents.get_attribute("wealth").tolist() == [1, 2, 3, 4] + + # Test with a list of attributes + assert agents.get_attribute(["wealth", "age"]).columns.tolist() == [ + "wealth", + "age", + ] + + # Test with a single attribute and a mask + selected = agents.select(agents["wealth"] > 1, inplace=False) + assert selected.get_attribute("wealth", mask="active").tolist() == [2, 3, 4] + + def test_remove(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + agents.remove([0, 1]) + assert agents.agents.index.tolist() == [2, 3] + with pytest.raises(KeyError) as e: + agents.remove([1]) + assert "1" in str(e) + + def test_select(self, fix1_AgentSetPandas): + agents: ExampleAgentSet = fix1_AgentSetPandas + + # Test with default arguments. Should select all agents + selected = agents.select(inplace=False) + assert selected.active_agents.wealth.tolist() == agents.agents.wealth.tolist() + + # Test with a pd.Series[bool] + mask = pd.Series([True, False, True, True]) + selected = agents.select(mask, inplace=False) + assert selected.active_agents.index.tolist() == [0, 2, 3] + + # Test with a ListLike + mask = [0, 2] + selected = agents.select(mask, inplace=False) + assert selected.active_agents.index.tolist() == [0, 2] + + # Test with a pd.DataFrame + mask = pd.DataFrame({"unique_id": [0, 1]}) + selected = agents.select(mask, inplace=False) + assert selected.active_agents.index.tolist() == [0, 1] + + # Test with filter_func + def filter_func(agentset: AgentSetPandas) -> pd.Series: + return agentset.agents.wealth > 1 + + selected = agents.select(filter_func=filter_func, inplace=False) + assert selected.active_agents.index.tolist() == [1, 2, 3] + + # Test with n + selected = agents.select(n=3, inplace=False) + assert len(selected.active_agents) == 3 + + # Test with n, filter_func and mask + mask = pd.Series([True, False, True, True]) + selected = agents.select(mask, filter_func=filter_func, n=1, inplace=False) + assert any(el in selected.active_agents.index.tolist() for el in [2, 3]) + + def test_set_attribute(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + + # Test with a single attribute + agents.set_attribute("wealth", 0) + assert agents.agents.wealth.tolist() == [0, 0, 0, 0] + + # Test with a list of attributes + agents.set_attribute(["wealth", "age"], 1) + + # Test with a single attribute and a mask + selected = agents.select(agents["wealth"] > 1, inplace=False) + selected.set_attribute("wealth", 0, mask="active") + assert selected.agents.wealth.tolist() == [1, 0, 0, 0] + + # Test with a dictionary + agents.set_attribute({"wealth": 10, "age": 20}) + assert agents.agents.wealth.tolist() == [10, 10, 10, 10] + assert agents.agents.age.tolist() == [20, 20, 20, 20] + + def test_shuffle(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + for _ in range(10): + original_order = agents.agents.index.tolist() + agents.shuffle() + if original_order != agents.agents.index.tolist(): + return + assert False + + def test_sort(self, fix1_AgentSetPandas): + agents = fix1_AgentSetPandas + agents.sort("wealth", ascending=False) + assert agents.agents.wealth.tolist() == [4, 3, 2, 1]