Skip to content

Commit

Permalink
minor fixes for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adamamer20 committed Jul 2, 2024
1 parent 7ee18ef commit d11fb3a
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 18 deletions.
2 changes: 1 addition & 1 deletion mesa_frames/abstract/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from numpy.random import Generator

from mesa_frames.abstract.mixin import CopyMixin
from mesa_frames.concrete.model import ModelDF
from mesa_frames.types import BoolSeries, DataFrame, IdsLike, Index, MaskLike, Series

if TYPE_CHECKING:
from mesa_frames.concrete.agents import AgentSetDF
from mesa_frames.concrete.model import ModelDF


class AgentContainer(CopyMixin):
Expand Down
28 changes: 17 additions & 11 deletions mesa_frames/concrete/agentset_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,14 +303,13 @@ def _concatenate_agentsets(
else:
final_df = pd.concat([obj._agents for obj in agentsets])
final_mask = pd.concat([obj._mask for obj in agentsets])
new_obj = self._get_obj(inplace=False)
new_obj._agents = final_df
new_obj._mask = final_mask
self._agents = final_df
self._mask = final_mask
if not isinstance(original_masked_index, type(None)):
ids_to_remove = original_masked_index.difference(self._agents.index)
if not ids_to_remove.empty:
new_obj.remove(ids_to_remove, inplace=True)
return new_obj
self.remove(ids_to_remove, inplace=True)
return self

def _get_bool_mask(
self,
Expand Down Expand Up @@ -390,12 +389,19 @@ def _update_mask(
new_active_indices: pd.Index | None = None,
) -> None:
# Update the mask with the old active agents and the new agents
self._mask = pd.Series(
self._agents.index.isin(original_active_indices)
| self._agents.index.isin(new_active_indices),
index=self._agents.index,
dtype=pd.BooleanDtype(),
)
if new_active_indices is None:
self._mask = pd.Series(
self._agents.index.isin(original_active_indices),
index=self._agents.index,
dtype=pd.BooleanDtype(),
)
else:
self._mask = pd.Series(
self._agents.index.isin(original_active_indices)
| self._agents.index.isin(new_active_indices),
index=self._agents.index,
dtype=pd.BooleanDtype(),
)

def __getattr__(self, name: str) -> Any:
super().__getattr__(name)
Expand Down
4 changes: 2 additions & 2 deletions mesa_frames/concrete/agentset_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import polars as pl
from polars.type_aliases import IntoExpr

from mesa_frames import AgentSetDF
from mesa_frames.concrete.agents import AgentSetDF
from mesa_frames.types import PolarsIdsLike, PolarsMaskLike

if TYPE_CHECKING:
Expand Down Expand Up @@ -209,7 +209,7 @@ def remove(self, ids: PolarsIdsLike, inplace: bool = True) -> Self:
if len(obj._agents) == initial_len:
raise KeyError(f"IDs {ids} not found in agent set.")

if isinstance(obj.mask, pl.Series):
if isinstance(obj._mask, pl.Series):
obj._update_mask(original_active_indices)
return obj

Expand Down
8 changes: 5 additions & 3 deletions mesa_frames/concrete/model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from collections.abc import Sequence
from typing import Any
from typing import TYPE_CHECKING, Any

import numpy as np

from mesa_frames.abstract.agents import AgentSetDF
from mesa_frames.concrete.agents import AgentsDF

if TYPE_CHECKING:
from mesa_frames.abstract.agents import AgentSetDF


class ModelDF:
"""Base class for models in the mesa-frames library.
Expand Down Expand Up @@ -75,7 +77,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
self.current_id = 0
self._agents = AgentsDF()

def get_agents_of_type(self, agent_type: type) -> AgentSetDF:
def get_agents_of_type(self, agent_type: type) -> "AgentSetDF":
"""Retrieve the AgentSetDF of a specified type.
Parameters
Expand Down
2 changes: 1 addition & 1 deletion tests/test_agentset_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test__init__(self):
assert agents.model == model
assert isinstance(agents.agents, pl.DataFrame)
assert agents.agents["unique_id"].to_list() == [0, 1, 2, 3]
assert isinstance(agents._mask, pl.Expr)
assert isinstance(agents._mask, pl.Series)
assert isinstance(agents.random, Generator)
assert agents.starting_wealth.to_list() == [1, 2, 3, 4]

Expand Down

0 comments on commit d11fb3a

Please sign in to comment.