Commit 81ed30d
Various cleanup of incorrect comments and old code, also minor refactoring
Biological Compatibility Benchmarks committed Dec 7, 2024
1 parent 78debe1 commit 81ed30d
Showing 7 changed files with 59 additions and 69 deletions.
1 change: 0 additions & 1 deletion aintelope/agents/abstract_agent.py
@@ -44,6 +44,5 @@ def update(
         score: float = 0.0,
         done: bool = False,
         test_mode: bool = False,
-        save_path: Optional[str] = None,  # TODO: this is unused right now
     ) -> list:
         ...
28 changes: 14 additions & 14 deletions aintelope/agents/example_agent.py
@@ -17,7 +17,7 @@

 from aintelope.environments.savanna_safetygrid import ACTION_RELATIVE_COORDINATE_MAP

-from aintelope.agents.q_agent import QAgent
+from aintelope.agents.abstract_agent import Agent
 from aintelope.aintelope_typing import ObservationFloat, PettingZooEnv
 from aintelope.training.dqn_training import Trainer

@@ -31,7 +31,7 @@

 logger = logging.getLogger("aintelope.agents.example_agent")


-class ExampleAgent(QAgent):
+class ExampleAgent(Agent):
     """Example agent class"""

     def __init__(
@@ -40,15 +40,22 @@ def __init__(
         trainer: Trainer,
         env: Environment = None,
         cfg: DictConfig = None,
+        **kwargs,
     ) -> None:
-        super().__init__(
-            agent_id=agent_id,
-            trainer=trainer,
-        )
+        self.id = agent_id
+        self.trainer = trainer
+        self.env = env
+        self.cfg = cfg
+        self.done = False
+        self.last_action = None

     def reset(self, state, info, env_class) -> None:
         """Resets self and updates the state."""
-        super().reset(state, info, env_class)
+        self.done = False
+        self.last_action = None
+        self.state = state
+        self.info = info
+        self.env_class = env_class

     def get_action(
         self,
@@ -64,11 +71,6 @@ def get_action(
         """Given an observation, ask your net what to do. State is needed to be
         given here as other agents have changed the state!
-        Args:
-            net: pytorch Module instance, the model
-            epsilon: value to determine likelihood of taking a random action
-            device: current device
         Returns:
             action (Optional[int]): index of action
         """
@@ -100,7 +102,6 @@ def update(
         score: float = 0.0,
         done: bool = False,
         test_mode: bool = False,
-        save_path: Optional[str] = None,  # TODO: this is unused right now
     ) -> list:
         """
         Takes observations and updates trainer on perceived experiences.
@@ -110,7 +111,6 @@ def update(
             observation: Tuple[ObservationArray, ObservationArray]
             score: Only baseline uses score as a reward
             done: boolean whether run is done
-            save_path: str
         Returns:
             agent_id (str): same as elsewhere ("agent_0" among them)
             state (Tuple[npt.NDArray[ObservationFloat], npt.NDArray[ObservationFloat]]): input for the net
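
For orientation: `ExampleAgent` is rebased from `QAgent` onto the abstract base class, so it now owns its own bookkeeping (`self.done`, `self.last_action`, stored state) instead of inheriting it. The sketch below reconstructs what the `Agent` interface in `aintelope/agents/abstract_agent.py` plausibly looks like after this commit; the method names and parameters are taken from the hunks above, while the defaults and docstrings are assumptions, not the repository's actual code.

```python
# Hypothetical sketch of the Agent interface, inferred from this diff only.
from abc import ABC, abstractmethod
from typing import Optional


class Agent(ABC):
    """Interface implemented by ExampleAgent and QAgent."""

    @abstractmethod
    def reset(self, state, info, env_class) -> None:
        """Clear per-episode flags and store the initial state."""

    @abstractmethod
    def get_action(
        self,
        observation=None,
        info=None,
        step: int = 0,
        trial: int = 0,
        episode: int = 0,
        pipeline_cycle: int = 0,
    ) -> Optional[int]:
        """Return the index of the next action."""

    @abstractmethod
    def update(
        self,
        observation=None,
        score: float = 0.0,
        done: bool = False,
        test_mode: bool = False,
    ) -> list:
        """Record the perceived transition and return the event row."""
```
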
11 changes: 3 additions & 8 deletions aintelope/agents/q_agent.py
@@ -13,7 +13,7 @@

 import numpy as np
 import numpy.typing as npt

-from aintelope.agents import Agent
+from aintelope.agents.abstract_agent import Agent
 from aintelope.aintelope_typing import ObservationFloat, PettingZooEnv
 from aintelope.training.dqn_training import Trainer

@@ -36,6 +36,7 @@ def __init__(
         trainer: Trainer,
         env: Environment = None,
         cfg: DictConfig = None,
+        **kwargs,
     ) -> None:
         self.id = agent_id
         self.trainer = trainer
@@ -65,11 +66,6 @@ def get_action(
         """Given an observation, ask your net what to do. State is needed to be
         given here as other agents have changed the state!
-        Args:
-            net: pytorch Module instance, the model
-            epsilon: value to determine likelihood of taking a random action
-            device: current device
         Returns:
             action (Optional[int]): index of action
         """
@@ -110,7 +106,6 @@ def update(
         score: float = 0.0,
         done: bool = False,
         test_mode: bool = False,
-        save_path: Optional[str] = None,  # TODO: this is unused right now
     ) -> list:
         """
         Takes observations and updates trainer on perceived experiences.
@@ -120,7 +115,6 @@ def update(
             observation: Tuple[ObservationArray, ObservationArray]
             score: Only baseline uses score as a reward
             done: boolean whether run is done
-            save_path: str
         Returns:
             agent_id (str): same as elsewhere ("agent_0" among them)
             state (Tuple[npt.NDArray[ObservationFloat], npt.NDArray[ObservationFloat]]): input for the net
@@ -133,6 +127,7 @@ def update(
         assert self.last_action is not None

         next_state = observation
+
         # TODO

         event = [self.id, self.state, self.last_action, score, done, next_state]
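
The `**kwargs` accepted by both constructors is what lets one uniform construction call work across agent classes with slightly different parameter lists: keys a given class does not use are absorbed instead of raising `TypeError`. A self-contained sketch of that pattern (generic class names, not repository code):

```python
# Illustrative sketch: why accepting **kwargs in __init__ helps a runner
# build any agent class from one shared argument dict.
class AgentA:
    def __init__(self, agent_id, cfg=None, **kwargs):
        self.id = agent_id
        self.cfg = cfg  # AgentA ignores env via **kwargs


class AgentB:
    def __init__(self, agent_id, env=None, **kwargs):
        self.id = agent_id
        self.env = env  # AgentB ignores cfg via **kwargs


shared = dict(agent_id="agent_0", cfg={"lr": 1e-3}, env="savanna")
for cls in (AgentA, AgentB):
    print(cls(**shared).id)  # both accept the same dict without TypeError
```
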
9 changes: 6 additions & 3 deletions aintelope/analytics/plotting.py
@@ -14,8 +14,9 @@

 import math
 import numpy as np
 import pandas as pd
-from matplotlib import cm
+
 from matplotlib import pyplot as plt
+
 import yaml

 """
@@ -54,7 +55,9 @@ def filter_train_and_test_events(
     score_dimensions = ["Reward"] + score_dimensions
     events[score_dimensions] = events[score_dimensions].astype(float)

-    if group_by_pipeline_cycle:
+    if (
+        group_by_pipeline_cycle
+    ):  # TODO: perhaps this branch is not needed and the "IsTest" column is sufficient in all cases?
         train_events = events[events["Pipeline cycle"] < num_train_pipeline_cycles]
         test_events = events[events["Pipeline cycle"] >= num_train_pipeline_cycles]
     else:
@@ -238,7 +241,7 @@ def plot_performance(
     ] = True  # ensure that plot labels fit to the image and do not overlap

     # fig = plt.figure()
-    fig, subplots = plt.subplots(3)
+    fig, subplots = plt.subplots(len(plot_datas))

     linewidth = 0.75  # TODO: config
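
Sizing the figure with `plt.subplots(len(plot_datas))` instead of a hard-coded `3` keeps one axis per plotted series, however many score dimensions the run produced. A standalone sketch of the pattern with made-up data:

```python
# Minimal sketch of the one-subplot-per-series pattern; the data is made up.
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
from matplotlib import pyplot as plt

plot_datas = {"Reward": [0, 1, 2], "Food": [2, 1, 0], "Water": [1, 1, 1]}

fig, subplots = plt.subplots(len(plot_datas))  # one axis per plotted series
for ax, (label, series) in zip(subplots, plot_datas.items()):
    ax.plot(series, linewidth=0.75)
    ax.set_ylabel(label)
fig.savefig("performance.png")
```
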
18 changes: 0 additions & 18 deletions aintelope/analytics/recording.py
@@ -94,24 +94,6 @@ def read_checkpoints(checkpoint_dir):
 ### Old stuff, not in use, but should belong here:


-def process_events(
-    events_df: pd.DataFrame,
-) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
-    """Function to convert the agent events dataframe into individual dataframe for
-    agent position, grass and water locations.
-    """
-    state_df = pd.DataFrame(events_df.state.to_list())
-    agent_df = pd.DataFrame(columns=["x", "y"], data=state_df.agent_coords.to_list())
-    grass_columns = [c for c in list(state_df) if c.startswith("grass")]
-    grass_df = state_df[grass_columns].applymap(lambda x: tuple(x))
-    grass_df = pd.DataFrame(columns=["x", "y"], data=set(grass_df.stack().to_list()))
-    water_columns = [c for c in list(state_df) if c.startswith("water")]
-    water_df = state_df[water_columns].applymap(lambda x: tuple(x))
-    water_df = pd.DataFrame(columns=["x", "y"], data=set(water_df.stack().to_list()))
-
-    return agent_df, grass_df, water_df
-
-
 def plot_events(agent, style: str = "thickness", color: str = "viridis") -> Figure:
     """
     Docstring missing, these are old functions I'm unsure are in use atm.
35 changes: 18 additions & 17 deletions aintelope/experiments.py
@@ -48,6 +48,24 @@ def run_experiment(
     else:
         raise NotImplementedError(f"Unknown environment type {type(env)}")

+    events = pd.DataFrame(
+        columns=[
+            "Run_id",
+            "Pipeline cycle",
+            "Episode",
+            "Trial",
+            "Step",
+            "IsTest",
+            "Agent_id",
+            "State",
+            "Action",
+            "Reward",
+            "Done",
+            "Next_state",
+        ]
+        + (score_dimensions if isinstance(env, GridworldZooBaseEnv) else ["Score"])
+    )
+
     # Common trainer for each agent's models
     trainer = Trainer(cfg)
@@ -148,23 +166,6 @@ def run_experiment(
     #         agents.play_step(self.net, epsilon=1.0)

     # Main loop
-    events = pd.DataFrame(
-        columns=[
-            "Run_id",
-            "Pipeline cycle",
-            "Episode",
-            "Trial",
-            "Step",
-            "IsTest",
-            "Agent_id",
-            "State",
-            "Action",
-            "Reward",
-            "Done",
-            "Next_state",
-        ]
-        + (score_dimensions if isinstance(env, GridworldZooBaseEnv) else ["Score"])
-    )

     model_needs_saving = (
         False  # if no training episodes are specified then do not save models
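
Hoisting the `events` DataFrame above the trainer setup means the accumulator is created once per experiment run instead of being defined down in the main-loop section. Judging by the column names, each row records one agent step; a hedged sketch of what appending such a row could look like (the values and the single `"Score"` dimension are illustrative, not taken from the repository):

```python
# Hedged sketch: appending one per-step event row. Column meanings are
# inferred from their names, not from repository code.
import pandas as pd

columns = [
    "Run_id", "Pipeline cycle", "Episode", "Trial", "Step", "IsTest",
    "Agent_id", "State", "Action", "Reward", "Done", "Next_state", "Score",
]
events = pd.DataFrame(columns=columns)

row = {
    "Run_id": "run_1", "Pipeline cycle": 0, "Episode": 0, "Trial": 0,
    "Step": 17, "IsTest": False, "Agent_id": "agent_0", "State": (0, 0),
    "Action": 2, "Reward": 1.0, "Done": False, "Next_state": (0, 1),
    "Score": 1.0,
}
events = pd.concat([events, pd.DataFrame([row])], ignore_index=True)
print(events.tail(1))
```
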
26 changes: 18 additions & 8 deletions aintelope/training/simple_eval.py
@@ -185,7 +185,12 @@ def run_episode(full_params: Dict) -> None:
             observation = observations[agent.id]
             info = infos[agent.id]
             actions[agent.id] = agent.get_action(
-                observation, info, step, trial=0, episode=0, pipeline_cycle=0
+                observation=observation,
+                info=info,
+                step=step,
+                trial=0,
+                episode=0,
+                pipeline_cycle=0,
             )

             logger.debug("debug actions", actions)
@@ -225,9 +230,9 @@ def run_episode(full_params: Dict) -> None:
             else:
                 # action = action_space(agent.id).sample()
                 action = agent.get_action(
-                    observation,
-                    info,
-                    step,
+                    observation=observation,
+                    info=info,
+                    step=step,
                     trial=0,
                     episode=0,
                     pipeline_cycle=0,
@@ -295,7 +300,12 @@ def run_episode(full_params: Dict) -> None:
             observation = observations[agent.id]
             info = infos[agent.id]
             actions[agent.id] = agent.get_action(
-                observation, info, step, trial=0, episode=0, pipeline_cycle=0
+                observation=observation,
+                info=info,
+                step=step,
+                trial=0,
+                episode=0,
+                pipeline_cycle=0,
             )

             logger.debug("debug actions", actions)
@@ -337,9 +347,9 @@ def run_episode(full_params: Dict) -> None:
             else:
                 # action = action_space(agent.id).sample()
                 action = agent.get_action(
-                    observation,
-                    info,
-                    step,
+                    observation=observation,
+                    info=info,
+                    step=step,
                     trial=0,
                     episode=0,
                     pipeline_cycle=0,
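
Switching these call sites from positional to keyword arguments is defensive: if the parameter order of `get_action` ever changes, each value still binds to the intended parameter, whereas positional calls would shuffle silently. A minimal illustration with a toy signature (not the repository's):

```python
# Toy illustration: keyword calls survive parameter reordering.
def get_action_v1(observation=None, info=None, step=0):
    return observation, info, step

def get_action_v2(info=None, observation=None, step=0):  # order changed
    return observation, info, step

args = dict(observation=[1.0], info={}, step=3)
assert get_action_v1(**args) == get_action_v2(**args)  # keywords still bind correctly
```
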
