Skip to content

Commit

Permalink
bug fixes, added tests for AgentsDF
Browse files Browse the repository at this point in the history
  • Loading branch information
adamamer20 committed Jul 4, 2024
1 parent 70821db commit d5201f8
Show file tree
Hide file tree
Showing 10 changed files with 1,319 additions and 209 deletions.
22 changes: 8 additions & 14 deletions mesa_frames/abstract/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@
from abc import abstractmethod
from collections.abc import Callable, Collection, Iterable, Iterator, Sequence
from contextlib import suppress
from typing import TYPE_CHECKING, Literal

from numpy.random import Generator
from typing_extensions import Any, Self, overload

from typing import TYPE_CHECKING, Literal

from mesa_frames.abstract.mixin import CopyMixin
from mesa_frames.types import BoolSeries, DataFrame, IdsLike, Index, MaskLike, Series

Expand Down Expand Up @@ -86,7 +85,7 @@ def discard(self, agents, inplace: bool = True) -> Self:
----------
Self
"""
with suppress(KeyError):
with suppress(KeyError, ValueError):
return self.remove(agents, inplace=inplace)
return self._get_obj(inplace)

Expand Down Expand Up @@ -363,22 +362,20 @@ def sort(
def __add__(self, other) -> Self:
return self.add(agents=other, inplace=False)

def __contains__(self, id: int) -> bool:
def __contains__(self, agents: int | IdsLike | AgentSetDF) -> bool:
"""Check if an agent is in the AgentContainer.
Parameters
----------
id : Hashable
The ID(s) to check for.
id : int | IdsLike | AgentSetDF
The ID(s) or AgentSetDF to check for.
Returns
-------
bool
True if the agent is in the AgentContainer, False otherwise.
"""
if not isinstance(id, int):
raise TypeError("id must be an integer")
return self.contains(ids=id)
return self.contains(agents=agents)

def __getitem__(
self,
Expand Down Expand Up @@ -511,7 +508,7 @@ def __getattr__(self, name: str) -> Any | dict[str, Any]:
"""

@abstractmethod
def __iter__(self) -> Iterator:
def __iter__(self) -> Iterator[dict[str, Any]]:
"""Iterate over the agents in the AgentContainer.
Returns
Expand Down Expand Up @@ -973,15 +970,12 @@ def __getitem__(
),
) -> Series | DataFrame:
attr = super().__getitem__(key)
assert isinstance(attr, (Series, DataFrame))
assert isinstance(attr, (Series, DataFrame, Index))
return attr

def __len__(self) -> int:
return len(self._agents)

def __iter__(self) -> Iterator:
return iter(self._agents)

def __repr__(self) -> str:
return f"{self.__class__.__name__}\n {str(self._agents)}"

Expand Down
12 changes: 10 additions & 2 deletions mesa_frames/abstract/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class CopyMixin(ABC):
_description_
"""

_copy_with_method: dict[str, tuple[str, list[str]]]
_copy_with_method: dict[str, tuple[str, list[str]]] = {}
_copy_only_reference: list[str] = [
"_model",
]
Expand All @@ -38,6 +38,7 @@ def copy(
self,
deep: bool = False,
memo: dict | None = None,
skip: list[str] | None = None,
) -> Self:
"""Create a copy of the Class.
Expand All @@ -48,10 +49,12 @@ def copy(
If True, all attributes of the AgentContainer will be recursively copied (except attributes in self._copy_reference_only).
If False, only the top-level attributes will be copied.
Defaults to False.
memo : dict | None, optional
A dictionary used to track already copied objects during deep copy.
Defaults to None.
skip : list[str] | None, optional
A list of attribute names to skip during the copy process.
Defaults to None.
Returns
-------
Expand All @@ -61,6 +64,9 @@ def copy(
cls = self.__class__
obj = cls.__new__(cls)

if skip is None:
skip = []

if deep:
if not memo:
memo = {}
Expand All @@ -71,13 +77,15 @@ def copy(
for k, v in attributes.items()
if k not in self._copy_with_method
and k not in self._copy_only_reference
and k not in skip
]
else:
[
setattr(obj, k, copy(v))
for k, v in self.__dict__.items()
if k not in self._copy_with_method
and k not in self._copy_only_reference
and k not in skip
]

# Copy attributes with a reference only
Expand Down
Loading

0 comments on commit d5201f8

Please sign in to comment.