feat: added prepare obs to all the algorithms (#267)
* feat: added prepare obs to all the algorithms

* feat: added prepare obs to training script

* feat: variable renaming
michele-milesi authored Apr 20, 2024
1 parent e692606 commit 5298755
Showing 26 changed files with 251 additions and 303 deletions.
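
The recurring change across all of these files is the same: the per-key `numpy`-to-`torch` conversion that each algorithm repeated inline in its training loop and `test` function now lives in a single `prepare_obs` helper per algorithm. A minimal, runnable sketch of the before/after pattern for a vector-only observation space (the key name, shapes, and the simplified helper below are illustrative, not copied from the repository):

```python
import numpy as np
import torch
from lightning import Fabric

fabric = Fabric(accelerator="cpu", devices=1)

# Illustrative batch of observations from 4 vectorized environments, 8-dim vector obs each
obs = {"state": np.random.rand(4, 8).astype(np.float32)}

# Before: every call site converted the dict inline
torch_obs_old = {k: torch.as_tensor(v, dtype=torch.float32, device=fabric.device) for k, v in obs.items()}

# After: the conversion is centralized in a per-algorithm helper (A2C/PPO-style shown here)
def prepare_obs(fabric: Fabric, obs, *, num_envs: int = 1, **kwargs):
    return {k: torch.from_numpy(v.copy()).to(fabric.device).float().reshape(num_envs, -1) for k, v in obs.items()}

torch_obs_new = prepare_obs(fabric, obs, num_envs=4)
assert torch.equal(torch_obs_old["state"], torch_obs_new["state"])  # same tensors, one owner of the logic
```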
51 changes: 27 additions & 24 deletions howto/register_external_algorithm.md
@@ -715,11 +715,16 @@ where `log_models`, `test` and `normalize_obs` have to be defined in the `my_awe
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Sequence

import gymnasium as gym
import numpy as np
import torch
from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule
from torch import Tensor

from sheeprl.utils.env import make_env
from sheeprl.utils.imports import _IS_MLFLOW_AVAILABLE
from sheeprl.utils.utils import unwrap_fabric

@@ -729,43 +734,41 @@ if TYPE_CHECKING:
from mlflow.models.model import ModelInfo


def prepare_obs(
fabric: Fabric, obs: Dict[str, np.ndarray], *, cnn_keys: Sequence[str] = [], num_envs: int = 1, **kwargs
) -> Dict[str, Tensor]:
torch_obs = {}
for k in obs.keys():
torch_obs[k] = torch.from_numpy(obs[k].copy()).to(fabric.device).float()
if k in cnn_keys:
torch_obs[k] = torch_obs[k].reshape(num_envs, -1, *torch_obs[k].shape[-2:])
else:
torch_obs[k] = torch_obs[k].reshape(num_envs, -1)
return normalize_obs(torch_obs, cnn_keys, obs.keys())


@torch.no_grad()
def test(agent: SOTAAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
agent.eval()
done = False
cumulative_rew = 0
o = env.reset(seed=cfg.seed)[0]
obs = {}
for k in o.keys():
if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder:
torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
if k in cfg.algo.cnn_keys.encoder:
torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5
if k in cfg.algo.mlp_keys.encoder:
torch_obs = torch_obs.float()
obs[k] = torch_obs
obs = env.reset(seed=cfg.seed)[0]

while not done:
torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder)

# Act greedly through the environment
if agent.is_continuous:
actions = torch.cat(agent.get_greedy_actions(obs), dim=-1)
actions = agent.get_actions(torch_obs, greedy=True)
if agent.actor.is_continuous:
actions = torch.cat(actions, dim=-1)
else:
actions = torch.cat([act.argmax(dim=-1) for act in agent.get_greedy_actions(obs)], dim=-1)
actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

# Single environment step
o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
obs, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
done = done or truncated
cumulative_rew += reward
obs = {}
for k in o.keys():
if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder:
torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
if k in cfg.algo.cnn_keys.encoder:
torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5
if k in cfg.algo.mlp_keys.encoder:
torch_obs = torch_obs.float()
obs[k] = torch_obs

if cfg.dry_run:
done = True
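
For reference, the `prepare_obs` added to this howto can be exercised on its own once `normalize_obs` is in scope. Below is a self-contained sketch; the observation keys and shapes are made up, and the inline `normalize_obs` is a stand-in for the project's helper (assumed here to rescale image keys to `[-0.5, 0.5]`, matching the test code this diff removes):

```python
from typing import Dict, Sequence

import numpy as np
import torch
from lightning import Fabric
from torch import Tensor


def normalize_obs(obs: Dict[str, Tensor], cnn_keys: Sequence[str], obs_keys) -> Dict[str, Tensor]:
    # Stand-in for sheeprl's normalize_obs: rescale image observations from [0, 255] to [-0.5, 0.5]
    return {k: obs[k] / 255 - 0.5 if k in cnn_keys else obs[k] for k in obs_keys}


def prepare_obs(
    fabric: Fabric, obs: Dict[str, np.ndarray], *, cnn_keys: Sequence[str] = [], num_envs: int = 1, **kwargs
) -> Dict[str, Tensor]:
    # Same body as the helper added to the howto above
    torch_obs = {}
    for k in obs.keys():
        torch_obs[k] = torch.from_numpy(obs[k].copy()).to(fabric.device).float()
        if k in cnn_keys:
            torch_obs[k] = torch_obs[k].reshape(num_envs, -1, *torch_obs[k].shape[-2:])
        else:
            torch_obs[k] = torch_obs[k].reshape(num_envs, -1)
    return normalize_obs(torch_obs, cnn_keys, obs.keys())


fabric = Fabric(accelerator="cpu", devices=1)
obs = {
    "rgb": np.random.randint(0, 256, (3, 64, 64), dtype=np.uint8),  # single-env channel-first image
    "state": np.random.rand(8).astype(np.float32),
}
torch_obs = prepare_obs(fabric, obs, cnn_keys=["rgb"])
print(torch_obs["rgb"].shape, torch_obs["state"].shape)  # torch.Size([1, 3, 64, 64]) torch.Size([1, 8])
```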
49 changes: 24 additions & 25 deletions howto/register_new_algorithm.md
@@ -713,59 +713,58 @@ where `log_models`, `test` and `normalize_obs` have to be defined in the `sheepr
```python
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Dict
from typing import TYPE_CHECKING, Any, Dict, Sequence

import gymnasium as gym
import numpy as np
import torch
from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule
from torch import Tensor

from sheeprl.algos.sota.agent import SOTAAgentPlayer
from sheeprl.utils.env import make_env
from sheeprl.utils.imports import _IS_MLFLOW_AVAILABLE
from sheeprl.utils.utils import unwrap_fabric

if TYPE_CHECKING:
from mlflow.models.model import ModelInfo


def prepare_obs(
fabric: Fabric, obs: Dict[str, np.ndarray], *, cnn_keys: Sequence[str] = [], num_envs: int = 1, **kwargs
) -> Dict[str, Tensor]:
torch_obs = {}
for k in obs.keys():
torch_obs[k] = torch.from_numpy(obs[k].copy()).to(fabric.device).float()
if k in cnn_keys:
torch_obs[k] = torch_obs[k].reshape(num_envs, -1, *torch_obs[k].shape[-2:])
else:
torch_obs[k] = torch_obs[k].reshape(num_envs, -1)
return normalize_obs(torch_obs, cnn_keys, obs.keys())


@torch.no_grad()
def test(agent: SOTAAgentPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
agent.eval()
done = False
cumulative_rew = 0
o = env.reset(seed=cfg.seed)[0]
obs = {}
for k in o.keys():
if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder:
torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
if k in cfg.algo.cnn_keys.encoder:
torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5
if k in cfg.algo.mlp_keys.encoder:
torch_obs = torch_obs.float()
obs[k] = torch_obs
obs = env.reset(seed=cfg.seed)[0]

while not done:
torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder)

# Act greedly through the environment
actions = agent.get_actions(obs, greedy=True)
if agent.is_continuous:
actions = agent.get_actions(torch_obs, greedy=True)
if agent.actor.is_continuous:
actions = torch.cat(actions, dim=-1)
else:
actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

# Single environment step
o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
obs, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
done = done or truncated
cumulative_rew += reward
obs = {}
for k in o.keys():
if k in cfg.algo.mlp_keys.encoder + cfg.algo.cnn_keys.encoder:
torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
if k in cfg.algo.cnn_keys.encoder:
torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5
if k in cfg.algo.mlp_keys.encoder:
torch_obs = torch_obs.float()
obs[k] = torch_obs

if cfg.dry_run:
done = True
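
Both howtos also switch the test loop to `agent.get_actions(torch_obs, greedy=True)` and branch on `agent.actor.is_continuous` when flattening the per-branch actions for `env.step`. A small illustrative sketch of why the two branches differ (the branch tensors below are made up; for discrete branches the diff treats each entry as per-action scores to be arg-maxed):

```python
import torch

# One tensor per action branch, as returned for a single test environment
continuous_branches = [torch.tensor([[0.3, -0.1]]), torch.tensor([[0.7]])]  # 2-D branch + 1-D branch
discrete_branches = [torch.rand(1, 4), torch.rand(1, 3)]                    # 4-way branch + 3-way branch

# Continuous: the branch outputs already are the actions, so concatenation is enough
cont_actions = torch.cat(continuous_branches, dim=-1)                                # shape (1, 3)

# Discrete: arg-max each branch first, then concatenate the chosen indices
disc_actions = torch.cat([act.argmax(dim=-1) for act in discrete_branches], dim=-1)  # shape (2,)
print(cont_actions.shape, disc_actions.shape)
```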
8 changes: 4 additions & 4 deletions sheeprl/algos/a2c/a2c.py
@@ -12,7 +12,7 @@

from sheeprl.algos.a2c.agent import A2CAgent, build_agent
from sheeprl.algos.a2c.loss import policy_loss, value_loss
from sheeprl.algos.a2c.utils import test
from sheeprl.algos.a2c.utils import prepare_obs, test
from sheeprl.data import ReplayBuffer
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -236,7 +236,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Sample an action given the observation received by the environment
# This calls the `forward` method of the PyTorch module, escaping from Fabric
# because we don't want this to be a synchronization point
torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys}
torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs)
actions, _, values = player(torch_obs)
if is_continuous:
real_actions = torch.cat(actions, -1).cpu().numpy()
@@ -272,7 +272,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Update the step data
step_data["dones"] = dones[np.newaxis]
step_data["values"] = values.cpu().numpy()[np.newaxis]
step_data["actions"] = actions[np.newaxis]
step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1)
step_data["rewards"] = rewards[np.newaxis]
if cfg.buffer.memmap:
step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape))
@@ -304,7 +304,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
with torch.inference_mode():
torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys}
torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs)
next_values = player.get_values(torch_obs)
returns, advantages = gae(
local_data["rewards"].to(torch.float64),
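
In the A2C training loop the same helper replaces the inline `torch.as_tensor` comprehension, and the buffered actions switch from `actions[np.newaxis]` to an explicit `actions.reshape(1, cfg.env.num_envs, -1)`. One plausible reason for the explicit reshape, sketched with made-up arrays below, is that it always produces a trailing action dimension, which `np.newaxis` alone does not when there is a single discrete action per environment:

```python
import numpy as np

num_envs = 4

# Continuous actions with shape (num_envs, action_dim): both layouts agree
cont = np.random.rand(num_envs, 2).astype(np.float32)
assert cont[np.newaxis].shape == cont.reshape(1, num_envs, -1).shape == (1, 4, 2)

# A single discrete action per env with shape (num_envs,): only the reshape keeps an action axis
disc = np.random.randint(0, 3, size=(num_envs,))
print(disc[np.newaxis].shape)               # (1, 4)
print(disc.reshape(1, num_envs, -1).shape)  # (1, 4, 1)
```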
28 changes: 13 additions & 15 deletions sheeprl/algos/a2c/utils.py
@@ -2,47 +2,45 @@

from typing import Any, Dict

import numpy as np
import torch
from lightning import Fabric
from torch import Tensor

from sheeprl.algos.ppo.agent import PPOPlayer
from sheeprl.utils.env import make_env

AGGREGATOR_KEYS = {"Rewards/rew_avg", "Game/ep_len_avg", "Loss/value_loss", "Loss/policy_loss"}


def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *, num_envs: int = 1, **kwargs) -> Dict[str, Tensor]:
torch_obs = {k: torch.from_numpy(v.copy()).to(fabric.device).float().reshape(num_envs, -1) for k, v in obs.items()}
return torch_obs


@torch.no_grad()
def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str):
env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)()
agent.eval()
done = False
cumulative_rew = 0
o = env.reset(seed=cfg.seed)[0]
obs = {}
for k in o.keys():
if k in cfg.algo.mlp_keys.encoder:
torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
torch_obs = torch_obs.float()
obs[k] = torch_obs
obs = env.reset(seed=cfg.seed)[0]

while not done:
# Convert observations to tensors
torch_obs = prepare_obs(fabric, obs)

# Act greedly through the environment
actions = agent.get_actions(obs, greedy=True)
actions = agent.get_actions(torch_obs, greedy=True)
if agent.actor.is_continuous:
actions = torch.cat(actions, dim=-1)
else:
actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

# Single environment step
o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
obs, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape))
done = done or truncated
cumulative_rew += reward
obs = {}
for k in o.keys():
if k in cfg.algo.mlp_keys.encoder:
torch_obs = torch.from_numpy(o[k]).to(fabric.device).unsqueeze(0)
torch_obs = torch_obs.float()
obs[k] = torch_obs

if cfg.dry_run:
done = True
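
A2C's `prepare_obs` is the leanest variant: there are no image keys, so it only copies each array to the Fabric device and flattens it to `(num_envs, obs_dim)`. A quick check of the flattening on a made-up matrix observation:

```python
import numpy as np
import torch
from lightning import Fabric

fabric = Fabric(accelerator="cpu", devices=1)

def prepare_obs(fabric: Fabric, obs, *, num_envs: int = 1, **kwargs):
    # Same body as the helper added to sheeprl/algos/a2c/utils.py above
    return {k: torch.from_numpy(v.copy()).to(fabric.device).float().reshape(num_envs, -1) for k, v in obs.items()}

# A 2x3 matrix observation from the single test environment is flattened to a 6-dim vector
obs = {"state": np.arange(6, dtype=np.float32).reshape(2, 3)}
print(prepare_obs(fabric, obs)["state"].shape)  # torch.Size([1, 6])
```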
13 changes: 4 additions & 9 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -20,7 +20,7 @@
from sheeprl.algos.dreamer_v1.agent import WorldModel, build_agent
from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss
from sheeprl.algos.dreamer_v1.utils import compute_lambda_values
from sheeprl.algos.dreamer_v2.utils import test
from sheeprl.algos.dreamer_v2.utils import prepare_obs, test
from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -574,16 +574,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
axis=-1,
)
else:
normalized_obs = {}
for k in obs_keys:
torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device)
if k in cfg.algo.cnn_keys.encoder:
torch_obs = torch_obs / 255 - 0.5
normalized_obs[k] = torch_obs
mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")}
torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")}
if len(mask) == 0:
mask = None
real_actions = actions = player.get_exploration_actions(normalized_obs, mask=mask, step=policy_step)
real_actions = actions = player.get_exploration_actions(torch_obs, mask=mask, step=policy_step)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
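
In the Dreamer-V1 interaction loop the helper's output also carries any `mask*` observation keys, which are then split back out and passed to the player as action masks (or `None` when the environment defines none). A tiny sketch of that filtering step, with made-up key names and shapes:

```python
import torch

# Output of prepare_obs: regular observation keys plus optional action-mask keys
torch_obs = {
    "rgb": torch.zeros(1, 1, 3, 64, 64),
    "mask_buttons": torch.ones(1, 1, 5),
}

# Keep only the mask entries; fall back to None when the env exposes no masks
mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")}
if len(mask) == 0:
    mask = None
print(list(mask.keys()) if mask is not None else None)  # ['mask_buttons']
```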
13 changes: 4 additions & 9 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -24,7 +24,7 @@

from sheeprl.algos.dreamer_v2.agent import WorldModel, build_agent
from sheeprl.algos.dreamer_v2.loss import reconstruction_loss
from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test
from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, prepare_obs, test
from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, SequentialReplayBuffer
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -599,16 +599,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
axis=-1,
)
else:
normalized_obs = {}
for k in obs_keys:
torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device)
if k in cfg.algo.cnn_keys.encoder:
torch_obs = torch_obs / 255 - 0.5
normalized_obs[k] = torch_obs
mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")}
torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")}
if len(mask) == 0:
mask = None
real_actions = actions = player.get_actions(normalized_obs, mask=mask)
real_actions = actions = player.get_actions(torch_obs, mask=mask)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
33 changes: 19 additions & 14 deletions sheeprl/algos/dreamer_v2/utils.py
@@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from lightning import Fabric
Expand Down Expand Up @@ -101,6 +102,20 @@ def compute_lambda_values(
return torch.cat(list(reversed(lv)), dim=0)


def prepare_obs(
fabric: Fabric, obs: Dict[str, np.ndarray], *, cnn_keys: Sequence[str] = [], num_envs: int = 1, **kwargs
) -> Dict[str, Tensor]:
torch_obs = {}
for k, v in obs.items():
torch_obs[k] = torch.from_numpy(v.copy()).to(fabric.device).float()
if k in cnn_keys:
torch_obs[k] = torch_obs[k].view(1, num_envs, -1, *v.shape[-2:]) / 255 - 0.5
else:
torch_obs[k] = torch_obs[k].view(1, num_envs, -1)

return torch_obs


@torch.no_grad()
def test(
player: "PlayerDV2" | "PlayerDV1",
@@ -125,32 +140,22 @@ def test(
env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))()
done = False
cumulative_rew = 0
device = fabric.device
next_obs = env.reset(seed=cfg.seed)[0]
for k in next_obs.keys():
next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float()
obs = env.reset(seed=cfg.seed)[0]
player.num_envs = 1
player.init_states()
while not done:
# Act greedly through the environment
preprocessed_obs = {}
for k, v in next_obs.items():
if k in cfg.algo.cnn_keys.encoder:
preprocessed_obs[k] = v[None, ...].to(device) / 255 - 0.5
elif k in cfg.algo.mlp_keys.encoder:
preprocessed_obs[k] = v[None, ...].to(device)
torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder)
real_actions = player.get_actions(
preprocessed_obs, greedy, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")}
torch_obs, greedy, {k: v for k, v in torch_obs.items() if k.startswith("mask")}
)
if player.actor.is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
else:
real_actions = torch.cat([real_act.argmax(dim=-1) for real_act in real_actions], dim=-1).cpu().numpy()

# Single environment step
next_obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape))
for k in next_obs.keys():
next_obs[k] = torch.from_numpy(next_obs[k]).view(1, *next_obs[k].shape).float()
obs, reward, done, truncated, _ = env.step(real_actions.reshape(env.action_space.shape))
done = done or truncated or cfg.dry_run
cumulative_rew += reward
fabric.print("Test - Reward:", cumulative_rew)
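
The Dreamer flavor of `prepare_obs` differs from the A2C/PPO ones in two ways visible above: image keys are rescaled inline (`/ 255 - 0.5`) instead of going through `normalize_obs`, and every tensor gets a leading dimension, giving `(1, num_envs, ...)` shapes, presumably the single timestep the recurrent player expects. A self-contained shape check (the keys and sizes below are made up):

```python
import numpy as np
import torch
from lightning import Fabric

fabric = Fabric(accelerator="cpu", devices=1)

def prepare_obs(fabric: Fabric, obs, *, cnn_keys=[], num_envs: int = 1, **kwargs):
    # Same body as the helper added to sheeprl/algos/dreamer_v2/utils.py above
    torch_obs = {}
    for k, v in obs.items():
        torch_obs[k] = torch.from_numpy(v.copy()).to(fabric.device).float()
        if k in cnn_keys:
            torch_obs[k] = torch_obs[k].view(1, num_envs, -1, *v.shape[-2:]) / 255 - 0.5
        else:
            torch_obs[k] = torch_obs[k].view(1, num_envs, -1)
    return torch_obs

obs = {
    "rgb": np.zeros((2, 3, 64, 64), dtype=np.uint8),   # 2 envs, channel-first images
    "state": np.zeros((2, 8), dtype=np.float32),
}
torch_obs = prepare_obs(fabric, obs, cnn_keys=["rgb"], num_envs=2)
print(torch_obs["rgb"].shape)    # torch.Size([1, 2, 3, 64, 64]) -> (timestep, envs, C, H, W)
print(torch_obs["state"].shape)  # torch.Size([1, 2, 8])
```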
