From a62a929bd2d0c8c9f6ef5fa05536bbe9cca3495f Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Sat, 5 Nov 2022 16:08:04 -0400
Subject: [PATCH 1/3] Handle truncation properly with PPO

---
 .../ppo_atari_envpool_xla_jax_truncation.py | 484 ++++++++++++++++++
 1 file changed, 484 insertions(+)
 create mode 100644 cleanrl/ppo_atari_envpool_xla_jax_truncation.py

diff --git a/cleanrl/ppo_atari_envpool_xla_jax_truncation.py b/cleanrl/ppo_atari_envpool_xla_jax_truncation.py
new file mode 100644
index 000000000..9d526ffc7
--- /dev/null
+++ b/cleanrl/ppo_atari_envpool_xla_jax_truncation.py
@@ -0,0 +1,484 @@
+"""
+Handle truncation
+python -i ppo_atari_envpool_xla_jax_truncation.py --env-id Breakout-v5 --num-envs 1 --num-steps 8 --num-minibatches 2 --update-epochs 2
+
+>>> storage.dones.flatten()
+DeviceArray([0., 0., 0., 0., 0., 0., 1., 0.], dtype=float32)
+>>> storage.truncations.flatten()
+DeviceArray([0., 0., 0., 0., 0., 0., 1., 0.], dtype=float32)
+>>> storage.rewards.flatten()
+DeviceArray([0., 0., 1., 0., 0., 0., 0., 0.], dtype=float32)
+>>> storage.values.flatten()
+DeviceArray([ 0.00226192, 0.00071621, 0.00114149, -0.00414939,
+             -0.00838596, -0.01181885, -0.01047847, 0.00127411], dtype=float32)
+
+# bootstrap value
+>>> jnp.where(storage.truncations, storage.rewards + storage.values, storage.rewards).flatten()
+DeviceArray([ 0. , 0. , 1. , 0. ,
+              0. , 0. , -0.01047847, 0. ], dtype=float32)
+
+"""
+
+
+# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpool_xla_jaxpy
+import argparse
+import os
+import random
+import time
+from distutils.util import strtobool
+from typing import Sequence
+
+os.environ[
+    "XLA_PYTHON_CLIENT_MEM_FRACTION"
+] = "0.7" # see https://github.com/google/jax/discussions/6332#discussioncomment-1279991
+
+import envpool
+import flax
+import flax.linen as nn
+import gym
+import jax
+import jax.numpy as jnp
+import numpy as np
+import optax
+from flax.linen.initializers import constant, orthogonal
+from flax.training.train_state import TrainState
+from torch.utils.tensorboard import SummaryWriter
+
+
+def parse_args():
+    # fmt: off
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"),
+        help="the name of this experiment")
+    parser.add_argument("--seed", type=int, default=1,
+        help="seed of the experiment")
+    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="if toggled, `torch.backends.cudnn.deterministic=False`")
+    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="if toggled, cuda will be enabled by default")
+    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="if toggled, this experiment will be tracked with Weights and Biases")
+    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
+        help="the wandb's project name")
+    parser.add_argument("--wandb-entity", type=str, default=None,
+        help="the entity (team) of wandb's project")
+    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
+        help="whether to capture videos of the agent performances (check out `videos` folder)")
+
+    # Algorithm specific arguments
+    parser.add_argument("--env-id", type=str, default="Pong-v5",
+        help="the id of the environment")
+    parser.add_argument("--total-timesteps", type=int, default=10000000,
+        help="total timesteps of the experiments")
+    parser.add_argument("--learning-rate", type=float, default=2.5e-4,
+        help="the learning rate of the optimizer")
+    parser.add_argument("--num-envs", type=int, default=8,
+        help="the number of parallel game environments")
+    parser.add_argument("--num-steps", type=int, default=128,
+        help="the number of steps to run in each environment per policy rollout")
+    parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="Toggle learning rate annealing for policy and value networks")
+    parser.add_argument("--gamma", type=float, default=0.99,
+        help="the discount factor gamma")
+    parser.add_argument("--gae-lambda", type=float, default=0.95,
+        help="the lambda for the general advantage estimation")
+    parser.add_argument("--num-minibatches", type=int, default=4,
+        help="the number of mini-batches")
+    parser.add_argument("--update-epochs", type=int, default=4,
+        help="the K epochs to update the policy")
+    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
+        help="Toggles advantages normalization")
+    parser.add_argument("--clip-coef", type=float, default=0.1,
+        help="the surrogate clipping coefficient")
+    parser.add_argument("--ent-coef", type=float, default=0.01,
+        help="coefficient of the entropy")
+    parser.add_argument("--vf-coef", type=float, default=0.5,
+        help="coefficient of the value function")
+    parser.add_argument("--max-grad-norm", type=float, default=0.5,
+        help="the maximum norm for the gradient clipping")
+    parser.add_argument("--target-kl", type=float, default=None,
+        help="the target KL divergence threshold")
+    args = parser.parse_args()
+    args.batch_size = int(args.num_envs * args.num_steps)
+    args.minibatch_size = int(args.batch_size // args.num_minibatches)
+    args.num_updates = args.total_timesteps // args.batch_size
+    # fmt: on
+    return args
+
+
+class Network(nn.Module):
+    @nn.compact
+    def __call__(self, x):
+        x = jnp.transpose(x, (0, 2, 3, 1))
+        x = x / (255.0)
+        x = nn.Conv(
+            32,
+            kernel_size=(8, 8),
+            strides=(4, 4),
+            padding="VALID",
+            kernel_init=orthogonal(np.sqrt(2)),
+            bias_init=constant(0.0),
+        )(x)
+        x = nn.relu(x)
+        x = nn.Conv(
+            64,
+            kernel_size=(4, 4),
+            strides=(2, 2),
+            padding="VALID",
+            kernel_init=orthogonal(np.sqrt(2)),
+            bias_init=constant(0.0),
+        )(x)
+        x = nn.relu(x)
+        x = nn.Conv(
+            64,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding="VALID",
+            kernel_init=orthogonal(np.sqrt(2)),
+            bias_init=constant(0.0),
+        )(x)
+        x = nn.relu(x)
+        x = x.reshape((x.shape[0], -1))
+        x = nn.Dense(512, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x)
+        x = nn.relu(x)
+        return x
+
+
+class Critic(nn.Module):
+    @nn.compact
+    def __call__(self, x):
+        return nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))(x)
+
+
+class Actor(nn.Module):
+    action_dim: Sequence[int]
+
+    @nn.compact
+    def __call__(self, x):
+        return nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(x)
+
+
+@flax.struct.dataclass
+class AgentParams:
+    network_params: flax.core.FrozenDict
+    actor_params: flax.core.FrozenDict
+    critic_params: flax.core.FrozenDict
+
+
+@flax.struct.dataclass
+class Storage:
+    obs: jnp.array
+    actions: jnp.array
+    logprobs: jnp.array
+    dones: jnp.array
+    values: jnp.array
+    advantages: jnp.array
+    returns: jnp.array
+    rewards: jnp.array
+    truncations: jnp.array
+
+
+@flax.struct.dataclass
+class EpisodeStatistics:
+    episode_returns: jnp.array
+    episode_lengths: jnp.array
+    returned_episode_returns: jnp.array
+    returned_episode_lengths: jnp.array
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
+    if args.track:
+        import wandb
+
+        wandb.init(
+            project=args.wandb_project_name,
+            entity=args.wandb_entity,
+            sync_tensorboard=True,
+            config=vars(args),
+            name=run_name,
+            monitor_gym=True,
+            save_code=True,
+        )
+    writer = SummaryWriter(f"runs/{run_name}")
+    writer.add_text(
+        "hyperparameters",
+        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+    )
+
+    # TRY NOT TO MODIFY: seeding
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    key = jax.random.PRNGKey(args.seed)
+    key, network_key, actor_key, critic_key = jax.random.split(key, 4)
+
+    # env setup
+    envs = envpool.make(
+        args.env_id,
+        env_type="gym",
+        num_envs=args.num_envs,
+        episodic_life=True,
+        reward_clip=True,
+        seed=args.seed,
+    )
+    envs.num_envs = args.num_envs
+    envs.single_action_space = envs.action_space
+    envs.single_observation_space = envs.observation_space
+    envs.is_vector_env = True
+    episode_stats = EpisodeStatistics(
+        episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32),
+        episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32),
+        returned_episode_returns=jnp.zeros(args.num_envs, dtype=jnp.float32),
+        returned_episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32),
+    )
+    handle, recv, send, step_env = envs.xla()
+
+    def step_env_wrappeed(episode_stats, handle, action):
+        handle, (next_obs, reward, next_done, info) = step_env(handle, action)
+        new_episode_return = episode_stats.episode_returns + info["reward"]
+        new_episode_length = episode_stats.episode_lengths + 1
+        episode_stats = episode_stats.replace(
+            episode_returns=(new_episode_return) * (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"]),
+            episode_lengths=(new_episode_length) * (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"]),
+            # only update the `returned_episode_returns` if the episode is done
+            returned_episode_returns=jnp.where(
+                info["terminated"] + info["TimeLimit.truncated"], new_episode_return, episode_stats.returned_episode_returns
+            ),
+            returned_episode_lengths=jnp.where(
+                info["terminated"] + info["TimeLimit.truncated"], new_episode_length, episode_stats.returned_episode_lengths
+            ),
+        )
+        return episode_stats, handle, (next_obs, reward, next_done, info)
+
+    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
+
+    def linear_schedule(count):
+        # anneal learning rate linearly after one training iteration which contains
+        # (args.num_minibatches * args.update_epochs) gradient updates
+        frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
+        return args.learning_rate * frac
+
+    network = Network()
+    actor = Actor(action_dim=envs.single_action_space.n)
+    critic = Critic()
+    network_params = network.init(network_key, np.array([envs.single_observation_space.sample()]))
+    agent_state = TrainState.create(
+        apply_fn=None,
+        params=AgentParams(
+            network_params,
+            actor.init(actor_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))),
+            critic.init(critic_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))),
+        ),
+        tx=optax.chain(
+            optax.clip_by_global_norm(args.max_grad_norm),
+            optax.inject_hyperparams(optax.adam)(
+                learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
+            ),
+        ),
+    )
+    network.apply = jax.jit(network.apply)
+    actor.apply = jax.jit(actor.apply)
+    critic.apply = jax.jit(critic.apply)
+
+    # ALGO Logic: Storage setup
+    storage = Storage(
+        obs=jnp.zeros((args.num_steps, args.num_envs) + envs.single_observation_space.shape),
+        actions=jnp.zeros((args.num_steps, args.num_envs) + envs.single_action_space.shape, dtype=jnp.int32),
+        logprobs=jnp.zeros((args.num_steps, args.num_envs)),
+        dones=jnp.zeros((args.num_steps, args.num_envs)),
+        values=jnp.zeros((args.num_steps, args.num_envs)),
+        advantages=jnp.zeros((args.num_steps, args.num_envs)),
+        returns=jnp.zeros((args.num_steps, args.num_envs)),
+        rewards=jnp.zeros((args.num_steps, args.num_envs)),
+        truncations=jnp.zeros((args.num_steps, args.num_envs)),
+    )
+
+    @jax.jit
+    def get_action_and_value(
+        agent_state: TrainState,
+        next_obs: np.ndarray,
+        next_done: np.ndarray,
+        next_truncated: np.ndarray,
+        storage: Storage,
+        step: int,
+        key: jax.random.PRNGKey,
+    ):
+        """sample action, calculate value, logprob, entropy, and update storage"""
+        hidden = network.apply(agent_state.params.network_params, next_obs)
+        logits = actor.apply(agent_state.params.actor_params, hidden)
+        # sample action: Gumbel-softmax trick
+        # see https://stats.stackexchange.com/questions/359442/sampling-from-a-categorical-distribution
+        key, subkey = jax.random.split(key)
+        u = jax.random.uniform(subkey, shape=logits.shape)
+        action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
+        logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
+        value = critic.apply(agent_state.params.critic_params, hidden)
+        storage = storage.replace(
+            obs=storage.obs.at[step].set(next_obs),
+            dones=storage.dones.at[step].set(next_done),
+            truncations=storage.truncations.at[step].set(next_truncated),
+            actions=storage.actions.at[step].set(action),
+            logprobs=storage.logprobs.at[step].set(logprob),
+            values=storage.values.at[step].set(value.squeeze()),
+        )
+        return storage, action, key
+
+    @jax.jit
+    def get_action_and_value2(
+        params: flax.core.FrozenDict,
+        x: np.ndarray,
+        action: np.ndarray,
+        mask: np.ndarray,
+    ):
+        """calculate value, logprob of supplied `action`, and entropy"""
+        hidden = network.apply(params.network_params, x)
+        logits = actor.apply(params.actor_params, hidden)
+        logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
+        # normalize the logits https://gregorygundersen.com/blog/2020/02/09/log-sum-exp/
+        logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
+        logits = logits.clip(min=jnp.finfo(logits.dtype).min)
+
+        # maks out truncated states during the learning pass so that they don't affect the loss
+        logits = jnp.where(mask.reshape((-1, 1)) * jnp.ones((1,4)), jnp.zeros_like(logits) -1e+8, logits)
+        p_log_p = logits * jax.nn.softmax(logits)
+        entropy = -p_log_p.sum(-1)
+        value = critic.apply(params.critic_params, hidden).squeeze()
+        return logprob, entropy, value
+
+    @jax.jit
+    def compute_gae(
+        agent_state: TrainState,
+        next_obs: np.ndarray,
+        next_done: np.ndarray,
+        storage: Storage,
+    ):
+        storage = storage.replace(advantages=storage.advantages.at[:].set(0.0))
+        next_value = critic.apply(
+            agent_state.params.critic_params, network.apply(agent_state.params.network_params, next_obs)
+        ).squeeze()
+        lastgaelam = 0
+        for t in reversed(range(args.num_steps)):
+            if t == args.num_steps - 1:
+                nextnonterminal = 1.0 - next_done
+                nextvalues = next_value
+            else:
+                nextnonterminal = 1.0 - storage.dones[t + 1]
+                nextvalues = storage.values[t + 1]
+            delta = storage.rewards[t] + args.gamma * nextvalues * nextnonterminal - storage.values[t]
+            lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
+            storage = storage.replace(advantages=storage.advantages.at[t].set(lastgaelam))
+        storage = storage.replace(returns=storage.advantages + storage.values)
+        return storage
+
+    @jax.jit
+    def update_ppo(
+        agent_state: TrainState,
+        storage: Storage,
+        key: jax.random.PRNGKey,
+    ):
+        # handle truncated trajectories
+        storage = storage.replace(rewards=jnp.where(storage.truncations, storage.rewards + storage.values, storage.rewards))
+        b_obs = storage.obs.reshape((-1,) + envs.single_observation_space.shape)
+        b_logprobs = storage.logprobs.reshape(-1)
+        b_actions = storage.actions.reshape((-1,) + envs.single_action_space.shape)
+        b_truncations = storage.truncations.reshape(-1)
+        b_advantages = storage.advantages.reshape(-1)
+        b_returns = storage.returns.reshape(-1)
+
+        def ppo_loss(params, x, a, truncation_mask, logp, mb_advantages, mb_returns):
+            newlogprob, entropy, newvalue = get_action_and_value2(params, x, a, truncation_mask)
+            logratio = newlogprob - logp
+            ratio = jnp.exp(logratio)
+            approx_kl = ((ratio - 1) - logratio).mean()
+
+            if args.norm_adv:
+                mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)
+
+            # Policy loss
+            pg_loss1 = -mb_advantages * ratio
+            pg_loss2 = -mb_advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
+            pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean()
+
+            # Value loss
+            v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean()
+
+            entropy_loss = entropy.mean()
+            loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
+            return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))
+
+        ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)
+        for _ in range(args.update_epochs):
+            key, subkey = jax.random.split(key)
+            b_inds = jax.random.permutation(subkey, args.batch_size, independent=True)
+            for start in range(0, args.batch_size, args.minibatch_size):
+                end = start + args.minibatch_size
+                mb_inds = b_inds[start:end]
+                (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
+                    agent_state.params,
+                    b_obs[mb_inds],
+                    b_actions[mb_inds],
+                    b_truncations[mb_inds],
+                    b_logprobs[mb_inds],
+                    b_advantages[mb_inds],
+                    b_returns[mb_inds],
+                )
+                agent_state = agent_state.apply_gradients(grads=grads)
+        return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key
+
+    # TRY NOT TO MODIFY: start the game
+    global_step = 0
+    start_time = time.time()
+    next_obs = envs.reset()
+    next_done = np.zeros(args.num_envs)
+    next_truncated = np.zeros(args.num_envs)
+
+    @jax.jit
+    def rollout(agent_state, episode_stats, next_obs, next_done, next_truncated, storage, key, handle, global_step):
+        for step in range(0, args.num_steps):
+            global_step += 1 * args.num_envs
+            storage, action, key = get_action_and_value(agent_state, next_obs, next_done, next_truncated, storage, step, key)
+
+            # TRY NOT TO MODIFY: execute the game and log data.
+            episode_stats, handle, (next_obs, reward, next_done, info) = step_env_wrappeed(episode_stats, handle, action)
+            storage = storage.replace(
+                rewards=storage.rewards.at[step].set(reward),
+            )
+            next_truncated = info["TimeLimit.truncated"]
+        return agent_state, episode_stats, next_obs, next_done, next_truncated, storage, key, handle, global_step
+
+    for update in range(1, args.num_updates + 1):
+        update_time_start = time.time()
+        agent_state, episode_stats, next_obs, next_done, next_truncated, storage, key, handle, global_step = rollout(
+            agent_state, episode_stats, next_obs, next_done, next_truncated, storage, key, handle, global_step
+        )
+        storage = compute_gae(agent_state, next_obs, next_done, storage)
+
+        agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key = update_ppo(
+            agent_state,
+            storage,
+            key,
+        )
+
+        avg_episodic_return = np.mean(jax.device_get(episode_stats.returned_episode_returns))
+        print(f"global_step={global_step}, avg_episodic_return={avg_episodic_return}")
+
+        # TRY NOT TO MODIFY: record rewards for plotting purposes
+        writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
+        writer.add_scalar(
+            "charts/avg_episodic_length", np.mean(jax.device_get(episode_stats.returned_episode_lengths)), global_step
+        )
+        writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"].item(), global_step)
+        writer.add_scalar("losses/value_loss", v_loss.item(), global_step)
+        writer.add_scalar("losses/policy_loss", pg_loss.item(), global_step)
+        writer.add_scalar("losses/entropy", entropy_loss.item(), global_step)
+        writer.add_scalar("losses/approx_kl", approx_kl.item(), global_step)
+        writer.add_scalar("losses/loss", loss.item(), global_step)
+        print("SPS:", int(global_step / (time.time() - start_time)))
+        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
+        writer.add_scalar(
+            "charts/SPS_update", int(args.num_envs * args.num_steps / (time.time() - update_time_start)), global_step
+        )
+
+    envs.close()
+    writer.close()

From 044d9c70dff62fe86fb0b2438262b6c23acf6884 Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Fri, 11 Nov 2022 20:54:08 -0500
Subject: [PATCH 2/3] quick change

---
 cleanrl/ppo_atari_envpool_xla_jax_truncation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cleanrl/ppo_atari_envpool_xla_jax_truncation.py b/cleanrl/ppo_atari_envpool_xla_jax_truncation.py
index 9d526ffc7..a1d9cec74 100644
--- a/cleanrl/ppo_atari_envpool_xla_jax_truncation.py
+++ b/cleanrl/ppo_atari_envpool_xla_jax_truncation.py
@@ -340,7 +340,7 @@ def get_action_and_value2(
         logits = logits.clip(min=jnp.finfo(logits.dtype).min)
 
         # maks out truncated states during the learning pass so that they don't affect the loss
-        logits = jnp.where(mask.reshape((-1, 1)) * jnp.ones((1,4)), jnp.zeros_like(logits) -1e+8, logits)
+        logits = jnp.where(mask.reshape((-1, 1)) * jnp.ones((1,4)), -jnp.inf, logits)
         p_log_p = logits * jax.nn.softmax(logits)
         entropy = -p_log_p.sum(-1)
         value = critic.apply(params.critic_params, hidden).squeeze()

From 3fee1212fe481ebd91e754e1bcda5764ff8b1b2e Mon Sep 17 00:00:00 2001
From: Costa Huang
Date: Fri, 11 Nov 2022 20:54:51 -0500
Subject: [PATCH 3/3] update

---
 cleanrl/ppo_atari_envpool_xla_jax.py | 39 ++++++++++++++--------------
 1 file changed, 20 insertions(+), 19 deletions(-)

diff --git a/cleanrl/ppo_atari_envpool_xla_jax.py b/cleanrl/ppo_atari_envpool_xla_jax.py
index f47d28513..9f756ad9c 100644
--- a/cleanrl/ppo_atari_envpool_xla_jax.py
+++ b/cleanrl/ppo_atari_envpool_xla_jax.py
@@ -210,24 +210,6 @@ class EpisodeStatistics:
         returned_episode_lengths=jnp.zeros(args.num_envs, dtype=jnp.int32),
     )
     handle, recv, send, step_env = envs.xla()
-
-    def step_env_wrappeed(episode_stats, handle, action):
-        handle, (next_obs, reward, next_done, info) = step_env(handle, action)
-        new_episode_return = episode_stats.episode_returns + info["reward"]
-        new_episode_length = episode_stats.episode_lengths + 1
-        episode_stats = episode_stats.replace(
-            episode_returns=(new_episode_return) * (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"]),
-            episode_lengths=(new_episode_length) * (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"]),
-            # only update the `returned_episode_returns` if the episode is done
-            returned_episode_returns=jnp.where(
-                info["terminated"] + info["TimeLimit.truncated"], new_episode_return, episode_stats.returned_episode_returns
-            ),
-            returned_episode_lengths=jnp.where(
-                info["terminated"] + info["TimeLimit.truncated"], new_episode_length, episode_stats.returned_episode_lengths
-            ),
-        )
-        return episode_stats, handle, (next_obs, reward, next_done, info)
-
     assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"
 
     def linear_schedule(count):
@@ -405,7 +387,26 @@ def rollout(agent_state, episode_stats, next_obs, next_done, storage, key, handl
             storage, action, key = get_action_and_value(agent_state, next_obs, next_done, storage, step, key)
 
             # TRY NOT TO MODIFY: execute the game and log data.
-            episode_stats, handle, (next_obs, reward, next_done, _) = step_env_wrappeed(episode_stats, handle, action)
+            handle, (next_obs, reward, next_done, info) = step_env(handle, action)
+
+            # record episodic statistics
+            new_episode_return = episode_stats.episode_returns + info["reward"]
+            new_episode_length = episode_stats.episode_lengths + 1
+            episode_stats = episode_stats.replace(
+                episode_returns=(new_episode_return) * (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"]),
+                episode_lengths=(new_episode_length) * (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"]),
+                # only update the `returned_episode_returns` if the episode is done
+                returned_episode_returns=jnp.where(
+                    info["terminated"] + info["TimeLimit.truncated"],
+                    new_episode_return,
+                    episode_stats.returned_episode_returns,
+                ),
+                returned_episode_lengths=jnp.where(
+                    info["terminated"] + info["TimeLimit.truncated"],
+                    new_episode_length,
+                    episode_stats.returned_episode_lengths,
+                ),
+            )
             storage = storage.replace(rewards=storage.rewards.at[step].set(reward))
         return agent_state, episode_stats, next_obs, next_done, storage, key, handle, global_step
 