From d11fb3a6bb77cc12b03d8ca3be802d22f4cf3b17 Mon Sep 17 00:00:00 2001 From: Adam Amer <136176500+adamamer20@users.noreply.github.com> Date: Tue, 2 Jul 2024 14:25:03 +0200 Subject: [PATCH] minor fixes for tests --- mesa_frames/abstract/agents.py | 2 +- mesa_frames/concrete/agentset_pandas.py | 28 +++++++++++++++---------- mesa_frames/concrete/agentset_polars.py | 4 ++-- mesa_frames/concrete/model.py | 8 ++++--- tests/test_agentset_polars.py | 2 +- 5 files changed, 26 insertions(+), 18 deletions(-) diff --git a/mesa_frames/abstract/agents.py b/mesa_frames/abstract/agents.py index 6ab56ee..e82d877 100644 --- a/mesa_frames/abstract/agents.py +++ b/mesa_frames/abstract/agents.py @@ -8,11 +8,11 @@ from numpy.random import Generator from mesa_frames.abstract.mixin import CopyMixin -from mesa_frames.concrete.model import ModelDF from mesa_frames.types import BoolSeries, DataFrame, IdsLike, Index, MaskLike, Series if TYPE_CHECKING: from mesa_frames.concrete.agents import AgentSetDF + from mesa_frames.concrete.model import ModelDF class AgentContainer(CopyMixin): diff --git a/mesa_frames/concrete/agentset_pandas.py b/mesa_frames/concrete/agentset_pandas.py index faace82..f8e9a57 100644 --- a/mesa_frames/concrete/agentset_pandas.py +++ b/mesa_frames/concrete/agentset_pandas.py @@ -303,14 +303,13 @@ def _concatenate_agentsets( else: final_df = pd.concat([obj._agents for obj in agentsets]) final_mask = pd.concat([obj._mask for obj in agentsets]) - new_obj = self._get_obj(inplace=False) - new_obj._agents = final_df - new_obj._mask = final_mask + self._agents = final_df + self._mask = final_mask if not isinstance(original_masked_index, type(None)): ids_to_remove = original_masked_index.difference(self._agents.index) if not ids_to_remove.empty: - new_obj.remove(ids_to_remove, inplace=True) - return new_obj + self.remove(ids_to_remove, inplace=True) + return self def _get_bool_mask( self, @@ -390,12 +389,19 @@ def _update_mask( new_active_indices: pd.Index | None = None, ) -> None: # Update the mask with the old active agents and the new agents - self._mask = pd.Series( - self._agents.index.isin(original_active_indices) - | self._agents.index.isin(new_active_indices), - index=self._agents.index, - dtype=pd.BooleanDtype(), - ) + if new_active_indices is None: + self._mask = pd.Series( + self._agents.index.isin(original_active_indices), + index=self._agents.index, + dtype=pd.BooleanDtype(), + ) + else: + self._mask = pd.Series( + self._agents.index.isin(original_active_indices) + | self._agents.index.isin(new_active_indices), + index=self._agents.index, + dtype=pd.BooleanDtype(), + ) def __getattr__(self, name: str) -> Any: super().__getattr__(name) diff --git a/mesa_frames/concrete/agentset_polars.py b/mesa_frames/concrete/agentset_polars.py index 143dc8b..df0f575 100644 --- a/mesa_frames/concrete/agentset_polars.py +++ b/mesa_frames/concrete/agentset_polars.py @@ -4,7 +4,7 @@ import polars as pl from polars.type_aliases import IntoExpr -from mesa_frames import AgentSetDF +from mesa_frames.concrete.agents import AgentSetDF from mesa_frames.types import PolarsIdsLike, PolarsMaskLike if TYPE_CHECKING: @@ -209,7 +209,7 @@ def remove(self, ids: PolarsIdsLike, inplace: bool = True) -> Self: if len(obj._agents) == initial_len: raise KeyError(f"IDs {ids} not found in agent set.") - if isinstance(obj.mask, pl.Series): + if isinstance(obj._mask, pl.Series): obj._update_mask(original_active_indices) return obj diff --git a/mesa_frames/concrete/model.py b/mesa_frames/concrete/model.py index 802b845..c2c3964 100644 --- a/mesa_frames/concrete/model.py +++ b/mesa_frames/concrete/model.py @@ -1,11 +1,13 @@ from collections.abc import Sequence -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np -from mesa_frames.abstract.agents import AgentSetDF from mesa_frames.concrete.agents import AgentsDF +if TYPE_CHECKING: + from mesa_frames.abstract.agents import AgentSetDF + class ModelDF: """Base class for models in the mesa-frames library. @@ -75,7 +77,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self.current_id = 0 self._agents = AgentsDF() - def get_agents_of_type(self, agent_type: type) -> AgentSetDF: + def get_agents_of_type(self, agent_type: type) -> "AgentSetDF": """Retrieve the AgentSetDF of a specified type. Parameters diff --git a/tests/test_agentset_polars.py b/tests/test_agentset_polars.py index 9ee828f..3e54e6d 100644 --- a/tests/test_agentset_polars.py +++ b/tests/test_agentset_polars.py @@ -46,7 +46,7 @@ def test__init__(self): assert agents.model == model assert isinstance(agents.agents, pl.DataFrame) assert agents.agents["unique_id"].to_list() == [0, 1, 2, 3] - assert isinstance(agents._mask, pl.Expr) + assert isinstance(agents._mask, pl.Series) assert isinstance(agents.random, Generator) assert agents.starting_wealth.to_list() == [1, 2, 3, 4]