feat: added prepare obs to training script
michele-milesi committed Apr 17, 2024
1 parent 8d07f4d commit d3966f9
Showing 24 changed files with 117 additions and 153 deletions.
8 changes: 4 additions & 4 deletions sheeprl/algos/a2c/a2c.py
@@ -12,7 +12,7 @@

 from sheeprl.algos.a2c.agent import A2CAgent, build_agent
 from sheeprl.algos.a2c.loss import policy_loss, value_loss
-from sheeprl.algos.a2c.utils import test
+from sheeprl.algos.a2c.utils import prepare_obs, test
 from sheeprl.data import ReplayBuffer
 from sheeprl.utils.env import make_env
 from sheeprl.utils.logger import get_log_dir, get_logger
@@ -236,7 +236,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
 # Sample an action given the observation received by the environment
 # This calls the `forward` method of the PyTorch module, escaping from Fabric
 # because we don't want this to be a synchronization point
-torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys}
+torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs)
 actions, _, values = player(torch_obs)
 if is_continuous:
     real_actions = torch.cat(actions, -1).cpu().numpy()
@@ -272,7 +272,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
 # Update the step data
 step_data["dones"] = dones[np.newaxis]
 step_data["values"] = values.cpu().numpy()[np.newaxis]
-step_data["actions"] = actions[np.newaxis]
+step_data["actions"] = actions.reshape(1, cfg.env.num_envs, -1)
 step_data["rewards"] = rewards[np.newaxis]
 if cfg.buffer.memmap:
     step_data["returns"] = np.zeros_like(rewards, shape=(1, *rewards.shape))
@@ -304,7 +304,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

 # Estimate returns with GAE (https://arxiv.org/abs/1506.02438)
 with torch.inference_mode():
-    torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys}
+    torch_obs = prepare_obs(fabric, next_obs, num_envs=cfg.env.num_envs)
     next_values = player.get_values(torch_obs)
     returns, advantages = gae(
         local_data["rewards"].to(torch.float64),
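One behavioural detail in this file is the buffer write switching from `actions[np.newaxis]` to an explicit `actions.reshape(1, cfg.env.num_envs, -1)`. A small numpy illustration of the difference, not taken from the commit, assuming the player can return a flat per-env action array in the single discrete-action case:

import numpy as np

num_envs = 4
# Hypothetical discrete case: one action id per environment, shape (num_envs,).
actions = np.array([0, 2, 1, 3])

print(actions[np.newaxis].shape)               # (1, 4)    -> no trailing action dimension
print(actions.reshape(1, num_envs, -1).shape)  # (1, 4, 1) -> always (1, num_envs, action_dim)

When `actions` already has shape `(num_envs, action_dim)` the two writes coincide; the reshape just pins the buffer layout.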
4 changes: 2 additions & 2 deletions sheeprl/algos/a2c/utils.py
@@ -13,8 +13,8 @@
 AGGREGATOR_KEYS = {"Rewards/rew_avg", "Game/ep_len_avg", "Loss/value_loss", "Loss/policy_loss"}


-def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *args, **kwargs) -> Dict[str, Tensor]:
-    torch_obs = {k: torch.from_numpy(v[np.newaxis]).to(fabric.device).float() for k, v in obs.items()}
+def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], *, num_envs: int = 1, **kwargs) -> Dict[str, Tensor]:
+    torch_obs = {k: torch.from_numpy(v.copy()).to(fabric.device).float().reshape(num_envs, -1) for k, v in obs.items()}
     return torch_obs

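A minimal usage sketch for the new A2C helper (not part of the commit): it assumes sheeprl is installed, a CPU-only Fabric instance, and a single MLP observation key with illustrative shapes.

import numpy as np
from lightning.fabric import Fabric

from sheeprl.algos.a2c.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)

# Fake vectorized-env output: one MLP key with shape (num_envs, obs_dim).
num_envs = 4
next_obs = {"state": np.random.rand(num_envs, 8).astype(np.float32)}

torch_obs = prepare_obs(fabric, next_obs, num_envs=num_envs)
print({k: tuple(v.shape) for k, v in torch_obs.items()})  # expected: {'state': (4, 8)}

The keyword-only `num_envs` and the trailing `**kwargs` let different call sites pass a common set of keyword arguments (for example `cnn_keys` for the Dreamer variants) without breaking the MLP-only helpers.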
13 changes: 4 additions & 9 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -20,7 +20,7 @@
 from sheeprl.algos.dreamer_v1.agent import WorldModel, build_agent
 from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss
 from sheeprl.algos.dreamer_v1.utils import compute_lambda_values
-from sheeprl.algos.dreamer_v2.utils import test
+from sheeprl.algos.dreamer_v2.utils import prepare_obs, test
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
 from sheeprl.utils.env import make_env
 from sheeprl.utils.logger import get_log_dir, get_logger
@@ -574,16 +574,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         axis=-1,
     )
 else:
-    normalized_obs = {}
-    for k in obs_keys:
-        torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device)
-        if k in cfg.algo.cnn_keys.encoder:
-            torch_obs = torch_obs / 255 - 0.5
-        normalized_obs[k] = torch_obs
-    mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")}
+    torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
+    mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")}
     if len(mask) == 0:
         mask = None
-    real_actions = actions = player.get_exploration_actions(normalized_obs, mask=mask, step=policy_step)
+    real_actions = actions = player.get_exploration_actions(torch_obs, mask=mask, step=policy_step)
     actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
     if is_continuous:
         real_actions = torch.cat(real_actions, -1).cpu().numpy()
13 changes: 4 additions & 9 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -24,7 +24,7 @@

 from sheeprl.algos.dreamer_v2.agent import WorldModel, build_agent
 from sheeprl.algos.dreamer_v2.loss import reconstruction_loss
-from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test
+from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, prepare_obs, test
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, SequentialReplayBuffer
 from sheeprl.utils.env import make_env
 from sheeprl.utils.logger import get_log_dir, get_logger
@@ -599,16 +599,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         axis=-1,
     )
 else:
-    normalized_obs = {}
-    for k in obs_keys:
-        torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device)
-        if k in cfg.algo.cnn_keys.encoder:
-            torch_obs = torch_obs / 255 - 0.5
-        normalized_obs[k] = torch_obs
-    mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")}
+    torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
+    mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")}
     if len(mask) == 0:
         mask = None
-    real_actions = actions = player.get_actions(normalized_obs, mask=mask)
+    real_actions = actions = player.get_actions(torch_obs, mask=mask)
     actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
     if is_continuous:
         real_actions = torch.cat(real_actions, -1).cpu().numpy()
12 changes: 7 additions & 5 deletions sheeprl/algos/dreamer_v2/utils.py
@@ -102,14 +102,16 @@ def compute_lambda_values(
     return torch.cat(list(reversed(lv)), dim=0)


-def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], cnn_keys: Sequence[str] = []) -> Dict[str, Tensor]:
+def prepare_obs(
+    fabric: Fabric, obs: Dict[str, np.ndarray], *, cnn_keys: Sequence[str] = [], num_envs: int = 1, **kwargs
+) -> Dict[str, Tensor]:
     torch_obs = {}
     for k, v in obs.items():
-        torch_obs[k] = torch.from_numpy(v.copy()).to(fabric.device).view(1, *v.shape).float()
+        torch_obs[k] = torch.from_numpy(v.copy()).to(fabric.device).float()
         if k in cnn_keys:
-            torch_obs[k] = torch_obs[k][None, ...] / 255 - 0.5
+            torch_obs[k] = torch_obs[k].view(1, num_envs, -1, *v.shape[-2:]) / 255 - 0.5
         else:
-            torch_obs[k] = torch_obs[k][None, ...]
+            torch_obs[k] = torch_obs[k].view(1, num_envs, -1)

     return torch_obs

@@ -143,7 +145,7 @@ def test(
     player.init_states()
     while not done:
         # Act greedly through the environment
-        torch_obs = prepare_obs(fabric, o, cfg.algo.cnn_keys.encoder)
+        torch_obs = prepare_obs(fabric, o, cnn_keys=cfg.algo.cnn_keys.encoder)
         real_actions = player.get_actions(
             torch_obs, greedy, {k: v for k, v in torch_obs.items() if k.startswith("mask")}
         )
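A usage sketch for the Dreamer-V2 helper above (not from the commit), with illustrative key names and shapes; it assumes image observations arrive channel-first as (num_envs, C, H, W) uint8 arrays.

import numpy as np
from lightning.fabric import Fabric

from sheeprl.algos.dreamer_v2.utils import prepare_obs

fabric = Fabric(accelerator="cpu", devices=1)

# Fake vectorized-env output for 2 envs: one image key and one vector key.
num_envs = 2
obs = {
    "rgb": np.random.randint(0, 256, size=(num_envs, 3, 64, 64), dtype=np.uint8),
    "state": np.random.rand(num_envs, 5).astype(np.float32),
}

torch_obs = prepare_obs(fabric, obs, cnn_keys=["rgb"], num_envs=num_envs)
print(tuple(torch_obs["rgb"].shape))    # (1, 2, 3, 64, 64), rescaled to [-0.5, 0.5]
print(tuple(torch_obs["state"].shape))  # (1, 2, 5)

The previous body prepended singleton axes to a single environment's observation; the new one accepts env-batched arrays and keeps the environment axis explicit via `num_envs`.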
12 changes: 4 additions & 8 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -24,7 +24,7 @@

 from sheeprl.algos.dreamer_v3.agent import WorldModel, build_agent
 from sheeprl.algos.dreamer_v3.loss import reconstruction_loss
-from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, test
+from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, prepare_obs, test
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
 from sheeprl.envs.wrappers import RestartOnException
 from sheeprl.utils.distribution import (
@@ -566,15 +566,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         axis=-1,
     )
 else:
-    preprocessed_obs = {}
-    for k, v in obs.items():
-        preprocessed_obs[k] = torch.as_tensor(v[np.newaxis], dtype=torch.float32, device=device)
-        if k in cfg.algo.cnn_keys.encoder:
-            preprocessed_obs[k] = preprocessed_obs[k] / 255.0 - 0.5
-    mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")}
+    torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
+    mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")}
     if len(mask) == 0:
         mask = None
-    real_actions = actions = player.get_actions(preprocessed_obs, mask=mask)
+    real_actions = actions = player.get_actions(torch_obs, mask=mask)
     actions = torch.cat(actions, -1).cpu().numpy()
     if is_continuous:
         real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()
12 changes: 7 additions & 5 deletions sheeprl/algos/dreamer_v3/utils.py
@@ -77,14 +77,16 @@ def compute_lambda_values(
     return ret


-def prepare_obs(fabric: Fabric, obs: Dict[str, np.ndarray], cnn_keys: Sequence[str] = []) -> Dict[str, Tensor]:
+def prepare_obs(
+    fabric: Fabric, obs: Dict[str, np.ndarray], *, cnn_keys: Sequence[str] = [], num_envs: int = 1, **kwargs
+) -> Dict[str, Tensor]:
     torch_obs = {}
     for k, v in obs.items():
-        torch_obs[k] = torch.from_numpy(v.copy()).to(fabric.device).view(1, *v.shape).float()
+        torch_obs[k] = torch.from_numpy(v.copy()).to(fabric.device).float()
         if k in cnn_keys:
-            torch_obs[k] = torch_obs[k][None, ...] / 255 - 0.5
+            torch_obs[k] = torch_obs[k].view(1, num_envs, -1, *v.shape[-2:]) / 255 - 0.5
         else:
-            torch_obs[k] = torch_obs[k][None, ...]
+            torch_obs[k] = torch_obs[k].view(1, num_envs, -1)

     return torch_obs

@@ -118,7 +120,7 @@ def test(
     player.init_states()
     while not done:
         # Act greedly through the environment
-        torch_obs = prepare_obs(fabric, o, cfg.algo.cnn_keys.encoder)
+        torch_obs = prepare_obs(fabric, o, cnn_keys=cfg.algo.cnn_keys.encoder)
         real_actions = player.get_actions(
             torch_obs, greedy, {k: v for k, v in torch_obs.items() if k.startswith("mask")}
         )
11 changes: 6 additions & 5 deletions sheeprl/algos/droq/droq.py
@@ -18,7 +18,7 @@

 from sheeprl.algos.droq.agent import DROQAgent, build_agent
 from sheeprl.algos.sac.loss import entropy_loss, policy_loss
-from sheeprl.algos.sac.sac import test
+from sheeprl.algos.sac.utils import prepare_obs, test
 from sheeprl.data.buffers import ReplayBuffer
 from sheeprl.utils.env import make_env
 from sheeprl.utils.logger import get_log_dir, get_logger
@@ -305,9 +305,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
 else:
     with torch.inference_mode():
         # Sample an action given the observation received by the environment
-        actions = player(torch.from_numpy(obs).to(device))
+        torch_obs = prepare_obs(fabric, o, num_envs=cfg.env.num_envs)
+        actions = player(torch_obs)
         actions = actions.cpu().numpy()
-next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape))
+o, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape))

 if cfg.metric.log_level > 0 and "final_info" in infos:
     for i, agent_ep_info in enumerate(infos["final_info"]):
@@ -320,14 +321,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             fabric.print(f"Rank-0: policy_step={policy_step}, reward_env_{i}={ep_rew[-1]}")

 # Save the real next observation
-real_next_obs = copy.deepcopy(next_obs)
+real_next_obs = copy.deepcopy(o)
 if "final_observation" in infos:
     for idx, final_obs in enumerate(infos["final_observation"]):
         if final_obs is not None:
             for k, v in final_obs.items():
                 real_next_obs[k][idx] = v

-next_obs = np.concatenate([next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1).astype(np.float32)
+next_obs = np.concatenate([o[k] for k in cfg.algo.mlp_keys.encoder], axis=-1).astype(np.float32)
 real_next_obs = np.concatenate([real_next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1).astype(
     np.float32
 )
13 changes: 4 additions & 9 deletions sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
@@ -20,7 +20,7 @@
 from sheeprl.algos.dreamer_v1.agent import WorldModel
 from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss
 from sheeprl.algos.dreamer_v1.utils import compute_lambda_values
-from sheeprl.algos.dreamer_v2.utils import test
+from sheeprl.algos.dreamer_v2.utils import prepare_obs, test
 from sheeprl.algos.p2e_dv1.agent import build_agent
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
 from sheeprl.utils.env import make_env
@@ -598,16 +598,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         axis=-1,
     )
 else:
-    normalized_obs = {}
-    for k in obs_keys:
-        torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device)
-        if k in cfg.algo.cnn_keys.encoder:
-            torch_obs = torch_obs / 255 - 0.5
-        normalized_obs[k] = torch_obs
-    mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")}
+    torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
+    mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")}
     if len(mask) == 0:
         mask = None
-    real_actions = actions = player.get_exploration_actions(normalized_obs, mask=mask, step=policy_step)
+    real_actions = actions = player.get_exploration_actions(torch_obs, mask=mask, step=policy_step)
     actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
     if is_continuous:
         real_actions = torch.cat(real_actions, -1).cpu().numpy()
13 changes: 4 additions & 9 deletions sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
@@ -14,7 +14,7 @@
 from torchmetrics import SumMetric

 from sheeprl.algos.dreamer_v1.dreamer_v1 import train
-from sheeprl.algos.dreamer_v2.utils import test
+from sheeprl.algos.dreamer_v2.utils import prepare_obs, test
 from sheeprl.algos.p2e_dv1.agent import build_agent
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer
 from sheeprl.utils.env import make_env
@@ -253,16 +253,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
 # Measure environment interaction time: this considers both the model forward
 # to get the action given the observation and the time taken into the environment
 with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
-    normalized_obs = {}
-    for k in obs_keys:
-        torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device)
-        if k in cfg.algo.cnn_keys.encoder:
-            torch_obs = torch_obs / 255 - 0.5
-        normalized_obs[k] = torch_obs
-    mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")}
+    torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
+    mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")}
     if len(mask) == 0:
         mask = None
-    real_actions = actions = player.get_exploration_actions(normalized_obs, mask=mask, step=policy_step)
+    real_actions = actions = player.get_exploration_actions(torch_obs, mask=mask, step=policy_step)
     actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
     if is_continuous:
         real_actions = torch.cat(real_actions, -1).cpu().numpy()
13 changes: 4 additions & 9 deletions sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
@@ -19,7 +19,7 @@

 from sheeprl.algos.dreamer_v2.agent import WorldModel
 from sheeprl.algos.dreamer_v2.loss import reconstruction_loss
-from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test
+from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, prepare_obs, test
 from sheeprl.algos.p2e_dv2.agent import build_agent
 from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, SequentialReplayBuffer
 from sheeprl.utils.env import make_env
@@ -735,16 +735,11 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         axis=-1,
     )
 else:
-    normalized_obs = {}
-    for k in obs_keys:
-        torch_obs = torch.as_tensor(obs[k][np.newaxis], dtype=torch.float32, device=device)
-        if k in cfg.algo.cnn_keys.encoder:
-            torch_obs = torch_obs / 255 - 0.5
-        normalized_obs[k] = torch_obs
-    mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")}
+    torch_obs = prepare_obs(fabric, obs, cnn_keys=cfg.algo.cnn_keys.encoder, num_envs=cfg.env.num_envs)
+    mask = {k: v for k, v in torch_obs.items() if k.startswith("mask")}
     if len(mask) == 0:
         mask = None
-    real_actions = actions = player.get_actions(normalized_obs, mask=mask)
+    real_actions = actions = player.get_actions(torch_obs, mask=mask)
     actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
     if is_continuous:
         real_actions = torch.cat(real_actions, -1).cpu().numpy()