Fix/a2c evaluation (#195)
* Add A2C evaluation function + save config.yaml

* Better evaluation_registry explainability
belerico authored Jan 21, 2024
1 parent 5163fd1 commit 0e88153
Showing 5 changed files with 88 additions and 12 deletions.
1 change: 1 addition & 0 deletions sheeprl/__init__.py
@@ -33,6 +33,7 @@
 from sheeprl.algos.sac import sac_decoupled  # noqa: F401
 from sheeprl.algos.sac_ae import sac_ae  # noqa: F401

+from sheeprl.algos.a2c import evaluate as a2c_evaluate  # noqa: F401, isort:skip
 from sheeprl.algos.dreamer_v1 import evaluate as dreamer_v1_evaluate  # noqa: F401, isort:skip
 from sheeprl.algos.dreamer_v2 import evaluate as dreamer_v2_evaluate  # noqa: F401, isort:skip
 from sheeprl.algos.dreamer_v3 import evaluate as dreamer_v3_evaluate  # noqa: F401, isort:skip
5 changes: 4 additions & 1 deletion sheeprl/algos/a2c/a2c.py
@@ -19,7 +19,7 @@
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
 from sheeprl.utils.timer import timer
-from sheeprl.utils.utils import gae
+from sheeprl.utils.utils import gae, save_configs


 def train(
@@ -175,6 +175,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     # the optimizer and set up it with Fabric
     optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters(), _convert_="all")

+    if fabric.is_global_zero:
+        save_configs(cfg, log_dir)
+
     # Create a metric aggregator to log the metrics
     aggregator = None
     if not MetricAggregator.disabled:
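
The new `save_configs` call is guarded by `fabric.is_global_zero`, so only the rank-zero process writes to disk. Per the commit message it persists the run's `config.yaml` into the log directory, which is what later lets the evaluation entrypoint rebuild the same agent and environment. A minimal sketch of such a helper, assuming an OmegaConf/Hydra-style config — this is illustrative, not the actual `sheeprl.utils.utils.save_configs`:

import os
from typing import Any, Dict

from omegaconf import OmegaConf


def save_configs(cfg: Dict[str, Any], log_dir: str) -> None:
    # Hypothetical sketch: persist the fully resolved run configuration next to
    # the logs so evaluation can later rebuild the exact same agent/environment.
    os.makedirs(log_dir, exist_ok=True)
    OmegaConf.save(OmegaConf.create(cfg), os.path.join(log_dir, "config.yaml"))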
58 changes: 58 additions & 0 deletions sheeprl/algos/a2c/evaluate.py
@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+from typing import Any, Dict
+
+import gymnasium as gym
+from lightning import Fabric
+
+from sheeprl.algos.a2c.agent import build_agent
+from sheeprl.algos.a2c.utils import test
+from sheeprl.utils.env import make_env
+from sheeprl.utils.logger import get_log_dir, get_logger
+from sheeprl.utils.registry import register_evaluation
+
+
+@register_evaluation(algorithms="a2c")
+def evaluate_a2c(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
+    logger = get_logger(fabric, cfg)
+    if logger and fabric.is_global_zero:
+        fabric._loggers = [logger]
+        fabric.logger.log_hyperparams(cfg)
+    log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name)
+    fabric.print(f"Log dir: {log_dir}")
+
+    env = make_env(
+        cfg,
+        cfg.seed,
+        0,
+        log_dir,
+        "test",
+        vector_env_idx=0,
+    )()
+    observation_space = env.observation_space
+
+    if not isinstance(observation_space, gym.spaces.Dict):
+        raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}")
+    if len(cfg.algo.mlp_keys.encoder) == 0:
+        raise RuntimeError("You should specify at least one MLP key for the encoder: `algo.mlp_keys.encoder=[state]`")
+    for k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder:
+        if k in observation_space.keys() and len(observation_space[k].shape) > 1:
+            raise ValueError(
+                "Only environments with vector-only observations are supported by the A2C agent. "
+                f"The observation with key '{k}' has shape {observation_space[k].shape}. "
+                f"Provided environment: {cfg.env.id}"
+            )
+    if cfg.metric.log_level > 0:
+        fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder)
+        fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder)
+
+    is_continuous = isinstance(env.action_space, gym.spaces.Box)
+    is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
+    actions_dim = tuple(
+        env.action_space.shape
+        if is_continuous
+        else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n])
+    )
+    # Create the actor and critic models
+    agent = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"])
+    test(agent, fabric, cfg, log_dir)
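
Because `evaluate_a2c` is decorated with `@register_evaluation(algorithms="a2c")`, it is recorded in `evaluation_registry` under its module, together with the file it lives in (see the registry changes below). A dispatcher can then resolve and call the entrypoint dynamically; a sketch of such a lookup, assuming a checkpoint `state` dict is already loaded — the exact CLI wiring is not part of this diff:

import importlib

from sheeprl.utils.registry import evaluation_registry

module = "sheeprl.algos.a2c"
entrypoint = None
for evaluation in evaluation_registry[module]:
    if evaluation["name"] == "a2c":
        # Resolves to sheeprl.algos.a2c.evaluate.evaluate_a2c
        eval_module = importlib.import_module(f"{module}.{evaluation['evaluation_file']}")
        entrypoint = getattr(eval_module, evaluation["entrypoint"])
        break
# entrypoint can now be launched as entrypoint(fabric, cfg, state)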
17 changes: 11 additions & 6 deletions sheeprl/available_agents.py
@@ -11,19 +11,24 @@
 table.add_column("Decoupled")
 table.add_column("Evaluated by")

+# print(evaluation_registry)
+
 for module, implementations in algorithm_registry.items():
     for algo in implementations:
-        evaluation_entrypoint = "Undefined"
-        for evaluation in evaluation_registry[module]:
-            if algo["name"] == evaluation["name"]:
-                evaluation_entrypoint = evaluation["entrypoint"]
-                break
+        evaluated_by = "Undefined"
+        if module in evaluation_registry:
+            for evaluation in evaluation_registry[module]:
+                if algo["name"] == evaluation["name"]:
+                    evaluation_file = evaluation["evaluation_file"]
+                    evaluation_entrypoint = evaluation["entrypoint"]
+                    evaluated_by = module + "." + evaluation_file + "." + evaluation_entrypoint
+                    break
         table.add_row(
             module,
             algo["name"],
             algo["entrypoint"],
             str(algo["decoupled"]),
-            module + ".evaluate." + evaluation_entrypoint,
+            evaluated_by,
         )

 console = Console()
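
With this fix, the A2C row's "Evaluated by" column renders the full dotted path built from the registry entry, `sheeprl.algos.a2c` + `evaluate` + `evaluate_a2c`, i.e. `sheeprl.algos.a2c.evaluate.evaluate_a2c`. And since the lookup is now guarded by `if module in evaluation_registry`, modules without a registered evaluation show `Undefined` instead of raising a `KeyError`.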
19 changes: 14 additions & 5 deletions sheeprl/utils/registry.py
@@ -42,6 +42,7 @@ def _register_evaluation(fn: Callable[..., Any], algorithms: str | List[str]) ->
     entrypoint = fn.__name__
     module_split = fn.__module__.split(".")
     module = ".".join(module_split[:-1])
+    evaluation_file = module_split[-1]
     if isinstance(algorithms, str):
         algorithms = [algorithms]
     # Check that the algorithms which we want to register an evaluation function for
@@ -55,7 +56,7 @@ def _register_evaluation(fn: Callable[..., Any], algorithms: str | List[str]) ->
         registered_algo_names = {algo["name"] for algo in registered_algos}
         if len(set(algorithms) - registered_algo_names) > 0:
             raise ValueError(
-                f"You are trying to register the evaluation function `{module+'.'+entrypoint}` "
+                f"You are trying to register the evaluation function `{module+'.'+evaluation_file+'.'+entrypoint}` "
                 f"for algorithms which have not been registered for the module `{module}`!\n"
                 f"Registered algorithms: {', '.join(registered_algo_names)}\n"
                 f"Specified algorithms: {', '.join(algorithms)}"
@@ -64,17 +65,25 @@ def _register_evaluation(fn: Callable[..., Any], algorithms: str | List[str]) ->
     if registered_evals is None:
         evaluation_registry[module] = []
         for algorithm in algorithms:
-            evaluation_registry[module].append({"name": algorithm, "entrypoint": entrypoint})
+            evaluation_registry[module].append(
+                {"name": algorithm, "evaluation_file": evaluation_file, "entrypoint": entrypoint}
+            )
     else:
         for registered_eval in registered_evals:
             if registered_eval["name"] in algorithms:
                 raise ValueError(
-                    f"Cannot register the evaluate function `{module+'.'+entrypoint}` "
+                    f"Cannot register the evaluate function `{module+'.'+evaluation_file+'.'+entrypoint}` "
                     f"for the algorithm `{registered_eval['name']}`: "
-                    f"the evaluation function `{module+'.'+registered_eval['entrypoint']}` has already "
+                    "the evaluation function "
+                    f"`{module+'.'+registered_eval['evaluation_file']+'.'+registered_eval['entrypoint']}` has already "
                     f"been registered for the algorithm named `{registered_eval['name']}` in the module `{module}`!"
                 )
-        evaluation_registry[module].extend([{"name": algorithm, "entrypoint": entrypoint} for algorithm in algorithms])
+        evaluation_registry[module].extend(
+            [
+                {"name": algorithm, "evaluation_file": evaluation_file, "entrypoint": entrypoint}
+                for algorithm in algorithms
+            ]
+        )

     # add the decorated function to __all__ in algorithm
     mod = sys.modules[fn.__module__]
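
The net effect on the registry's data shape, using the A2C registration from this commit as the example:

# Each entry now records the evaluation file alongside the entrypoint, e.g.:
evaluation_registry["sheeprl.algos.a2c"] = [
    {"name": "a2c", "evaluation_file": "evaluate", "entrypoint": "evaluate_a2c"}
]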
