Adding pre-step, post-step, pre-reset, and post-reset callbacks to the environment wrapper class
Biological Compatibility Benchmarks committed Dec 7, 2024
1 parent 81ed30d commit a2a4223
Showing 1 changed file with 115 additions and 63 deletions.
178 changes: 115 additions & 63 deletions aintelope/environments/savanna_safetygrid.py
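Note on usage: the four callback attributes added in __init__ below are initialised to None, and this commit adds no public setter, so registration presumably happens by direct attribute assignment. A minimal sketch under that assumption (env stands for an already constructed wrapper instance; the callback name is illustrative):

def log_pre_step(actions):
    # Inspect or rewrite the joint action dict before the step runs.
    print("about to step with:", actions)
    return actions

env._pre_step_callback2 = log_pre_step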
@@ -76,8 +76,6 @@
Dict[AgentId, Info],
]

reset_count = 0 # for debugging


class GridworldZooBaseEnv:
metadata = {
@@ -253,6 +251,10 @@ def __init__(
self._combine_interoception_and_vision = self.metadata[
"combine_interoception_and_vision"
]
self._pre_reset_callback2 = None
self._post_reset_callback2 = None
self._pre_step_callback2 = None
self._post_step_callback2 = None

def init_observation_spaces(self, parent_observation_spaces, infos):
# for @zoo-api
@@ -413,6 +415,7 @@ def filter_info(self, agent: str, info: dict):
INFO_AGENT_INTEROCEPTION_VECTOR, # keeping interoception available in info, since in the observation it may either be located in its own vector or be part of the vision, which makes access to this data cumbersome when writing hardcoded rules. Accessing it via the info argument is more convenient in such cases.
INFO_AGENT_INTEROCEPTION_ORDER,
ACTION_RELATIVE_COORDINATE_MAP,
INFO_REWARD_DICT, # keep the reward dict for the case where the score is scalarised
]
result = {key: value for key, value in info.items() if key in allowed_keys}

@@ -430,29 +433,29 @@ def filter_infos(self, infos: dict):
for agent, agent_info in infos.items()
}

@property
def grass_patches(self):
any_agent = self._last_infos[
"agent_0"
] # any agent is good here since we are using global coordinates here
coordinates = any_agent[INFO_OBSERVATION_COORDINATES].get(FOOD_CHR, [])
if len(coordinates) > 0:
grass_patches = np.array(coordinates)
else:
grass_patches = np.zeros([0, 2])
return grass_patches

@property
def water_holes(self):
any_agent = self._last_infos[
"agent_0"
] # any agent is good here since we are using global coordinates here
coordinates = any_agent[INFO_OBSERVATION_COORDINATES].get(DRINK_CHR, [])
if len(coordinates) > 0:
water_holes = np.array(coordinates)
else:
water_holes = np.zeros([0, 2])
return water_holes
# @property
# def grass_patches(self):
# any_agent = self._last_infos[
# "agent_0"
# ] # any agent is good here since we are using global coordinates here
# coordinates = any_agent[INFO_OBSERVATION_COORDINATES].get(FOOD_CHR, [])
# if len(coordinates) > 0:
# grass_patches = np.array(coordinates)
# else:
# grass_patches = np.zeros([0, 2])
# return grass_patches

# @property
# def water_holes(self):
# any_agent = self._last_infos[
# "agent_0"
# ] # any agent is good here since we are using global coordinates here
# coordinates = any_agent[INFO_OBSERVATION_COORDINATES].get(DRINK_CHR, [])
# if len(coordinates) > 0:
# water_holes = np.array(coordinates)
# else:
# water_holes = np.zeros([0, 2])
# return water_holes

def observe_from_location(
self, agents_coordinates: Dict, agents_directions: Dict = None
@@ -473,30 +476,30 @@ def observe_from_location(
def observation_space(self, agent):
return self.transformed_observation_spaces[agent]

# called by DQNLightning
def state_to_namedtuple(self, state: npt.NDArray[ObservationFloat]) -> NamedTuple:
"""Method to convert a state array into a named tuple."""
agent_coords = {
"agent_coords": state[:2]
} # TODO: make it dependant on number of agents
grass_patches_coords = {}
gp_offset = 2
water_holes_coords = {}
wh_offset = 2 + self.metadata["amount_grass_patches"] * 2
for i in range(self.metadata["amount_grass_patches"]):
grass_patches_coords[f"grass_patch_{i}"] = state[
gp_offset + i : gp_offset + i + 2
]
for i in range(self.metadata["amount_water_holes"]):
water_holes_coords[f"water_hole_{i}"] = state[
wh_offset + i : wh_offset + i + 2
]

keys = (
list(agent_coords) + list(grass_patches_coords) + list(water_holes_coords)
)
StateTuple = namedtuple("StateTuple", {k: np.ndarray for k in keys})
return StateTuple(**agent_coords, **grass_patches_coords, **water_holes_coords)
## called by DQNLightning
# def state_to_namedtuple(self, state: npt.NDArray[ObservationFloat]) -> NamedTuple:
# """Method to convert a state array into a named tuple."""
# agent_coords = {
# "agent_coords": state[:2]
# } # TODO: make it dependent on the number of agents
# grass_patches_coords = {}
# gp_offset = 2
# water_holes_coords = {}
# wh_offset = 2 + self.metadata["amount_grass_patches"] * 2
# for i in range(self.metadata["amount_grass_patches"]):
# grass_patches_coords[f"grass_patch_{i}"] = state[
# gp_offset + i : gp_offset + i + 2
# ]
# for i in range(self.metadata["amount_water_holes"]):
# water_holes_coords[f"water_hole_{i}"] = state[
# wh_offset + i : wh_offset + i + 2
# ]

# keys = (
# list(agent_coords) + list(grass_patches_coords) + list(water_holes_coords)
# )
# StateTuple = namedtuple("StateTuple", {k: np.ndarray for k in keys})
# return StateTuple(**agent_coords, **grass_patches_coords, **water_holes_coords)

"""
This API is intended primarily as input for the neural network.
@@ -642,14 +645,24 @@ def __init__(
def reset(
self, seed: Optional[int] = None, options=None, *args, **kwargs
) -> Tuple[Dict[AgentId, Observation], Dict[AgentId, Info]]:
global reset_count

reset_count += 1
# print("env reset_count: " + str(reset_count))
if self._pre_reset_callback2 is not None:
(allow_reset, seed, options, args, kwargs) = self._pre_reset_callback2(
seed, options, *args, **kwargs
)
if not allow_reset:
return

observations, infos = GridworldZooParallelEnv.reset(
self, seed=seed, options=options, *args, **kwargs
)

print(
"trial_no: "
+ str(GridworldZooParallelEnv.get_trial_no(self))
+ " episode_no: "
+ str(GridworldZooParallelEnv.get_episode_no(self))
)

infos = self.format_infos(infos)
self._last_infos = infos
# transform observations
@@ -659,7 +672,12 @@
if self._override_infos:
infos = {agent: {} for agent in infos.keys()}

return self.observations2, self.filter_infos(infos)
result = (self.observations2, self.filter_infos(infos))

if self._post_reset_callback2 is not None:
self._post_reset_callback2(*result, seed, options, *args, **kwargs)

return result
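The call sites above imply the reset-hook contract: the pre-reset hook can veto or rewrite a reset, and the post-reset hook observes its result. A sketch of conforming callbacks (names are illustrative, not part of the commit):

def pre_reset(seed, options, *args, **kwargs):
    # The first element gates the reset: when falsy, reset() returns early.
    # The remaining elements replace the arguments forwarded to the parent reset.
    return (True, seed, options, args, kwargs)

def post_reset(observations, infos, seed, options, *args, **kwargs):
    # Receives the transformed observations and filtered infos that reset()
    # is about to return, followed by the original reset arguments.
    print("episode started with agents:", list(observations))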

def step(self, actions: Dict[str, Action]) -> Step:
"""step(action) takes in an action for each agent and should return the
@@ -674,10 +692,17 @@ def step(self, actions: Dict[str, Action]) -> Step:
{<agent_name>: <agent_action or None if agent is done>}
"""
logger.debug("debug actions", actions)

if self._pre_step_callback2 is not None:
actions = self._pre_step_callback2(actions)

# If a user passes in actions with no agents,
# then just return empty observations, etc.
if not actions:
return {}, {}, {}, {}, {}
result = {}, {}, {}, {}, {}
if self._post_step_callback2 is not None:
self._post_step_callback2(actions, *result)
return result

(
observations,
@@ -722,14 +747,19 @@ def step(self, actions: Dict[str, Action]) -> Step:
truncateds,
self.filter_infos(infos),
)
return (
result = (
self.observations2,
rewards2,
terminateds,
truncateds,
self.filter_infos(infos),
)

if self._post_step_callback2 is not None:
self._post_step_callback2(actions, *result)

return result
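Correspondingly for the parallel step hooks, a sketch (illustrative names): the pre-step hook may rewrite the joint action dict, and the post-step hook receives the actions together with the five-tuple that step() returns, on both the normal and the empty-actions path.

def pre_step(actions):
    # For example, substitute a no-op for a disabled agent here.
    return actions

def post_step(actions, observations, rewards, terminateds, truncateds, infos):
    all_done = all(terminateds.values()) if terminateds else False
    print("step complete, all terminated:", all_done)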


class SavannaGridworldSequentialEnv(GridworldZooBaseEnv, GridworldZooAecEnv):
def __init__(
@@ -783,7 +813,7 @@ def infos(
self,
):
"""Needed for tests.
Zoo is unable to compare infos unless they have simple structure.
Note: Zoo is unable to compare infos unless they have simple structure.
"""
infos = GridworldZooAecEnv.infos.fget(
self
@@ -821,13 +851,22 @@ def observe_info(self, agent):
def reset(
self, seed: Optional[int] = None, options=None, *args, **kwargs
) -> Tuple[Dict[AgentId, Observation], Dict[AgentId, Info]]:
global reset_count

reset_count += 1
# print("env reset_count: " + str(reset_count))
if self._pre_reset_callback2 is not None:
(allow_reset, seed, options, args, kwargs) = self._pre_reset_callback2(
seed, options, *args, **kwargs
)
if not allow_reset:
return # TODO!!! return value

GridworldZooAecEnv.reset(self, seed=seed, options=options, *args, **kwargs)

print(
"trial_no: "
+ str(GridworldZooParallelEnv.get_trial_no(self))
+ " episode_no: "
+ str(GridworldZooParallelEnv.get_episode_no(self))
)

# observe observations, transform observations
infos = {}
for agent in self.possible_agents:
@@ -851,7 +890,12 @@
if self._override_infos:
infos = {agent: {} for agent in infos.keys()}

return self.observations2, self.filter_infos(infos)
result = (self.observations2, self.filter_infos(infos))

if self._post_reset_callback2 is not None:
self._post_reset_callback2(*result, seed, options, *args, **kwargs)

return result

def last(self, observe=True):
"""Returns observation, cumulative reward, terminated, truncated, info for the
@@ -910,6 +954,9 @@ def step_single_agent(self, action: Action):

agent = self.agent_selection

if self._pre_step_callback2 is not None:
action = self._pre_step_callback2(agent, action)

# need to set current step rewards to zero for other agents
# the agent should be visible in .rewards after it dies
# (until its "dead step"), but during next agent's step
@@ -969,14 +1016,19 @@
truncated,
self.filter_info(agent, info),
)
return (
result = (
observation2,
reward2,
terminated,
truncated,
self.filter_info(agent, info),
)

if self._post_step_callback2 is not None:
self._post_step_callback2(agent, action, *result)

return result
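In the sequential (AEC) wrapper the same hooks are invoked per agent, so they take the agent id as a leading argument. A sketch (illustrative names):

def pre_step(agent, action):
    # May substitute the single agent's action.
    return action

def post_step(agent, action, observation, reward, terminated, truncated, info):
    print(agent, "received reward", reward)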

def step_multiple_agents(self, actions: Dict[str, Action]) -> Step:
"""step(action) takes in an action for each agent and should return the
- observations
