From ca3d0cefd246f51224e39cd1e89937723255841a Mon Sep 17 00:00:00 2001 From: Adam Zhao Date: Sun, 15 Oct 2023 12:30:39 +0800 Subject: [PATCH 1/3] clean-up --- cleanrl/ddpg_continuous_action.py | 3 +-- cleanrl/ddpg_continuous_action_jax.py | 3 +-- cleanrl/ppo.py | 9 +++++---- cleanrl/ppo_continuous_action.py | 6 ++---- cleanrl/rpo_continuous_action.py | 6 ++---- cleanrl/sac_atari.py | 2 +- cleanrl/sac_continuous_action.py | 3 ++- cleanrl/td3_continuous_action.py | 3 +-- cleanrl/td3_continuous_action_jax.py | 3 +-- 9 files changed, 16 insertions(+), 22 deletions(-) diff --git a/cleanrl/ddpg_continuous_action.py b/cleanrl/ddpg_continuous_action.py index cbeea9e6e..d42d3bc5a 100644 --- a/cleanrl/ddpg_continuous_action.py +++ b/cleanrl/ddpg_continuous_action.py @@ -149,7 +149,6 @@ def forward(self, x): "hyperparameters", "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), ) - video_filenames = set() # TRY NOT TO MODIFY: seeding random.seed(args.seed) @@ -205,7 +204,7 @@ def forward(self, x): writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) break - # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation` + # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` real_next_obs = next_obs.copy() for idx, trunc in enumerate(truncations): if trunc: diff --git a/cleanrl/ddpg_continuous_action_jax.py b/cleanrl/ddpg_continuous_action_jax.py index dc29adfbe..e074acd60 100644 --- a/cleanrl/ddpg_continuous_action_jax.py +++ b/cleanrl/ddpg_continuous_action_jax.py @@ -143,7 +143,6 @@ class TrainState(TrainState): "hyperparameters", "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), ) - video_filenames = set() # TRY NOT TO MODIFY: seeding random.seed(args.seed) @@ -258,7 +257,7 @@ def actor_loss(params): writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) break - # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation` + # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation` real_next_obs = next_obs.copy() for idx, trunc in enumerate(truncations): if trunc: diff --git a/cleanrl/ppo.py b/cleanrl/ppo.py index 1c4831f6d..506173175 100644 --- a/cleanrl/ppo.py +++ b/cleanrl/ppo.py @@ -78,11 +78,12 @@ def parse_args(): def make_env(env_id, seed, idx, capture_video, run_name): def thunk(): - env = gym.make(env_id) + if capture_video and idx == 0: + env = gym.make(env_id, render_mode="rgb_array") + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") + else: + env = gym.make(env_id) env = gym.wrappers.RecordEpisodeStatistics(env) - if capture_video: - if idx == 0: - env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") env.seed(seed) env.action_space.seed(seed) env.observation_space.seed(seed) diff --git a/cleanrl/ppo_continuous_action.py b/cleanrl/ppo_continuous_action.py index 884eefd87..b61ca73dd 100644 --- a/cleanrl/ppo_continuous_action.py +++ b/cleanrl/ppo_continuous_action.py @@ -78,15 +78,13 @@ def parse_args(): def make_env(env_id, idx, capture_video, run_name, gamma): def thunk(): - if capture_video: + if capture_video and idx == 0: env = gym.make(env_id, render_mode="rgb_array") + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") else: env = gym.make(env_id) env = gym.wrappers.FlattenObservation(env) # deal with dm_control's Dict observation space env = gym.wrappers.RecordEpisodeStatistics(env) - if capture_video: - if idx == 0: - env = 
gym.wrappers.RecordVideo(env, f"videos/{run_name}") env = gym.wrappers.ClipAction(env) env = gym.wrappers.NormalizeObservation(env) env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10)) diff --git a/cleanrl/rpo_continuous_action.py b/cleanrl/rpo_continuous_action.py index dfeb84391..f7338ee7c 100644 --- a/cleanrl/rpo_continuous_action.py +++ b/cleanrl/rpo_continuous_action.py @@ -80,15 +80,13 @@ def parse_args(): def make_env(env_id, idx, capture_video, run_name, gamma): def thunk(): - if capture_video: + if capture_video and idx == 0: env = gym.make(env_id, render_mode="rgb_array") + env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") else: env = gym.make(env_id) env = gym.wrappers.FlattenObservation(env) # deal with dm_control's Dict observation space env = gym.wrappers.RecordEpisodeStatistics(env) - if capture_video: - if idx == 0: - env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") env = gym.wrappers.ClipAction(env) env = gym.wrappers.NormalizeObservation(env) env = gym.wrappers.TransformObservation(env, lambda obs: np.clip(obs, -10, 10)) diff --git a/cleanrl/sac_atari.py b/cleanrl/sac_atari.py index 994719321..07fa8e53c 100644 --- a/cleanrl/sac_atari.py +++ b/cleanrl/sac_atari.py @@ -246,7 +246,7 @@ def get_action(self, x): start_time = time.time() # TRY NOT TO MODIFY: start the game - obs, info = envs.reset(seed=args.seed) + obs, _ = envs.reset(seed=args.seed) for global_step in range(args.total_timesteps): # ALGO LOGIC: put action logic here if global_step < args.learning_starts: diff --git a/cleanrl/sac_continuous_action.py b/cleanrl/sac_continuous_action.py index 3c91399cc..a12beec64 100644 --- a/cleanrl/sac_continuous_action.py +++ b/cleanrl/sac_continuous_action.py @@ -217,7 +217,7 @@ def get_action(self, x): start_time = time.time() # TRY NOT TO MODIFY: start the game - obs, info = envs.reset(seed=args.seed) + obs, _ = envs.reset(seed=args.seed) for global_step in range(args.total_timesteps): # ALGO LOGIC: put action logic here if global_step < args.learning_starts: @@ -263,6 +263,7 @@ def get_action(self, x): qf2_loss = F.mse_loss(qf2_a_values, next_q_value) qf_loss = qf1_loss + qf2_loss + # optimize the model q_optimizer.zero_grad() qf_loss.backward() q_optimizer.step() diff --git a/cleanrl/td3_continuous_action.py b/cleanrl/td3_continuous_action.py index bf6564b61..837e27faf 100644 --- a/cleanrl/td3_continuous_action.py +++ b/cleanrl/td3_continuous_action.py @@ -151,7 +151,6 @@ def forward(self, x): "hyperparameters", "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), ) - video_filenames = set() # TRY NOT TO MODIFY: seeding random.seed(args.seed) @@ -210,7 +209,7 @@ def forward(self, x): writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) break - # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation` + # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` real_next_obs = next_obs.copy() for idx, trunc in enumerate(truncations): if trunc: diff --git a/cleanrl/td3_continuous_action_jax.py b/cleanrl/td3_continuous_action_jax.py index 6c17ca145..3c584c6f3 100644 --- a/cleanrl/td3_continuous_action_jax.py +++ b/cleanrl/td3_continuous_action_jax.py @@ -145,7 +145,6 @@ class TrainState(TrainState): "hyperparameters", "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), ) - video_filenames = set() # TRY NOT TO MODIFY: seeding random.seed(args.seed) @@ -292,7 +291,7 @@ def 
actor_loss(params): writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) break - # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation` + # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation` real_next_obs = next_obs.copy() for idx, trunc in enumerate(truncations): if trunc: From d1d28c7488ede97c8b8e45c5b14c0a1ee6a19fe5 Mon Sep 17 00:00:00 2001 From: Adam Zhao Date: Sun, 15 Oct 2023 12:56:57 +0800 Subject: [PATCH 2/3] update next_obs, rewards, terminations, truncations, infos --- cleanrl/c51.py | 11 ++++++----- cleanrl/c51_atari.py | 14 ++++++++------ cleanrl/c51_atari_jax.py | 14 ++++++++------ cleanrl/c51_jax.py | 11 ++++++----- cleanrl/dqn.py | 9 +++++---- cleanrl/dqn_atari.py | 12 +++++++----- cleanrl/dqn_atari_jax.py | 12 +++++++----- cleanrl/dqn_jax.py | 11 +++++------ cleanrl/ppo_continuous_action.py | 4 ++-- cleanrl/qdagger_dqn_atari_impalacnn.py | 16 ++++++++-------- cleanrl/rpo_continuous_action.py | 4 ++-- cleanrl/sac_atari.py | 4 ++-- 12 files changed, 66 insertions(+), 56 deletions(-) diff --git a/cleanrl/c51.py b/cleanrl/c51.py index 00b7018d0..3959466f1 100755 --- a/cleanrl/c51.py +++ b/cleanrl/c51.py @@ -196,7 +196,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): actions = actions.cpu().numpy() # TRY NOT TO MODIFY: execute the game and log data. - next_obs, rewards, terminated, truncated, infos = envs.step(actions) + next_obs, rewards, terminations, truncations, infos = envs.step(actions) # TRY NOT TO MODIFY: record rewards for plotting purposes if "final_info" in infos: @@ -208,13 +208,14 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) writer.add_scalar("charts/epsilon", epsilon, global_step) + break # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` real_next_obs = next_obs.copy() - for idx, d in enumerate(truncated): - if d: + for idx, trunc in enumerate(truncations): + if trunc: real_next_obs[idx] = infos["final_observation"][idx] - rb.add(obs, real_next_obs, actions, rewards, terminated, infos) + rb.add(obs, real_next_obs, actions, rewards, terminations, infos) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs @@ -257,7 +258,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): loss.backward() optimizer.step() - # update the target network + # update target network if global_step % args.target_network_frequency == 0: target_network.load_state_dict(q_network.state_dict()) diff --git a/cleanrl/c51_atari.py b/cleanrl/c51_atari.py index 238e0e5a3..8e47bacc5 100755 --- a/cleanrl/c51_atari.py +++ b/cleanrl/c51_atari.py @@ -95,6 +95,7 @@ def thunk(): else: env = gym.make(env_id) env = gym.wrappers.RecordEpisodeStatistics(env) + env = NoopResetEnv(env, noop_max=30) env = MaxAndSkipEnv(env, skip=4) env = EpisodicLifeEnv(env) @@ -104,8 +105,8 @@ def thunk(): env = gym.wrappers.ResizeObservation(env, (84, 84)) env = gym.wrappers.GrayScaleObservation(env) env = gym.wrappers.FrameStack(env, 4) - env.action_space.seed(seed) + env.action_space.seed(seed) return env return thunk @@ -218,7 +219,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): actions = actions.cpu().numpy() # TRY NOT TO MODIFY: execute the game and log data. 
- next_obs, rewards, terminated, truncated, infos = envs.step(actions) + next_obs, rewards, terminations, truncations, infos = envs.step(actions) # TRY NOT TO MODIFY: record rewards for plotting purposes if "final_info" in infos: @@ -230,13 +231,14 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) writer.add_scalar("charts/epsilon", epsilon, global_step) + break # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` real_next_obs = next_obs.copy() - for idx, d in enumerate(truncated): - if d: + for idx, trunc in enumerate(truncations): + if trunc: real_next_obs[idx] = infos["final_observation"][idx] - rb.add(obs, real_next_obs, actions, rewards, terminated, infos) + rb.add(obs, real_next_obs, actions, rewards, terminations, infos) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs @@ -279,7 +281,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): loss.backward() optimizer.step() - # update the target network + # update target network if global_step % args.target_network_frequency == 0: target_network.load_state_dict(q_network.state_dict()) diff --git a/cleanrl/c51_atari_jax.py b/cleanrl/c51_atari_jax.py index bebdc47b5..93c436ec5 100644 --- a/cleanrl/c51_atari_jax.py +++ b/cleanrl/c51_atari_jax.py @@ -98,6 +98,7 @@ def thunk(): else: env = gym.make(env_id) env = gym.wrappers.RecordEpisodeStatistics(env) + env = NoopResetEnv(env, noop_max=30) env = MaxAndSkipEnv(env, skip=4) env = EpisodicLifeEnv(env) @@ -107,8 +108,8 @@ def thunk(): env = gym.wrappers.ResizeObservation(env, (84, 84)) env = gym.wrappers.GrayScaleObservation(env) env = gym.wrappers.FrameStack(env, 4) - env.action_space.seed(seed) + env.action_space.seed(seed) return env return thunk @@ -278,7 +279,7 @@ def get_action(q_state, obs): actions = jax.device_get(actions) # TRY NOT TO MODIFY: execute the game and log data. 
- next_obs, rewards, terminated, truncated, infos = envs.step(actions) + next_obs, rewards, terminations, truncations, infos = envs.step(actions) # TRY NOT TO MODIFY: record rewards for plotting purposes if "final_info" in infos: @@ -290,13 +291,14 @@ def get_action(q_state, obs): writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) writer.add_scalar("charts/epsilon", epsilon, global_step) + break # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` real_next_obs = next_obs.copy() - for idx, d in enumerate(truncated): - if d: + for idx, trunc in enumerate(truncations): + if trunc: real_next_obs[idx] = infos["final_observation"][idx] - rb.add(obs, real_next_obs, actions, rewards, terminated, infos) + rb.add(obs, real_next_obs, actions, rewards, terminations, infos) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs @@ -319,7 +321,7 @@ def get_action(q_state, obs): print("SPS:", int(global_step / (time.time() - start_time))) writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) - # update the target network + # update target network if global_step % args.target_network_frequency == 0: q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1)) diff --git a/cleanrl/c51_jax.py b/cleanrl/c51_jax.py index d8818e713..4b65f3595 100644 --- a/cleanrl/c51_jax.py +++ b/cleanrl/c51_jax.py @@ -242,7 +242,7 @@ def loss(q_params, observations, actions, target_pmfs): actions = jax.device_get(actions) # TRY NOT TO MODIFY: execute the game and log data. - next_obs, rewards, terminated, truncated, infos = envs.step(actions) + next_obs, rewards, terminations, truncations, infos = envs.step(actions) # TRY NOT TO MODIFY: record rewards for plotting purposes if "final_info" in infos: @@ -254,13 +254,14 @@ def loss(q_params, observations, actions, target_pmfs): writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) writer.add_scalar("charts/epsilon", epsilon, global_step) + break # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` real_next_obs = next_obs.copy() - for idx, d in enumerate(truncated): - if d: + for idx, trunc in enumerate(truncations): + if trunc: real_next_obs[idx] = infos["final_observation"][idx] - rb.add(obs, real_next_obs, actions, rewards, terminated, infos) + rb.add(obs, real_next_obs, actions, rewards, terminations, infos) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs @@ -283,7 +284,7 @@ def loss(q_params, observations, actions, target_pmfs): print("SPS:", int(global_step / (time.time() - start_time))) writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) - # update the target network + # update target network if global_step % args.target_network_frequency == 0: q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1)) diff --git a/cleanrl/dqn.py b/cleanrl/dqn.py index 02940644b..2aa8f9bc6 100644 --- a/cleanrl/dqn.py +++ b/cleanrl/dqn.py @@ -183,7 +183,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): actions = torch.argmax(q_values, dim=1).cpu().numpy() # TRY NOT TO MODIFY: execute the game and log data. 
- next_obs, rewards, terminated, truncated, infos = envs.step(actions) + next_obs, rewards, terminations, truncations, infos = envs.step(actions) # TRY NOT TO MODIFY: record rewards for plotting purposes if "final_info" in infos: @@ -195,13 +195,14 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) writer.add_scalar("charts/epsilon", epsilon, global_step) + break # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` real_next_obs = next_obs.copy() - for idx, d in enumerate(truncated): - if d: + for idx, trunc in enumerate(truncations): + if trunc: real_next_obs[idx] = infos["final_observation"][idx] - rb.add(obs, real_next_obs, actions, rewards, terminated, infos) + rb.add(obs, real_next_obs, actions, rewards, terminations, infos) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs diff --git a/cleanrl/dqn_atari.py b/cleanrl/dqn_atari.py index 84f9228e4..a4c3df339 100644 --- a/cleanrl/dqn_atari.py +++ b/cleanrl/dqn_atari.py @@ -92,6 +92,7 @@ def thunk(): else: env = gym.make(env_id) env = gym.wrappers.RecordEpisodeStatistics(env) + env = NoopResetEnv(env, noop_max=30) env = MaxAndSkipEnv(env, skip=4) env = EpisodicLifeEnv(env) @@ -101,8 +102,8 @@ def thunk(): env = gym.wrappers.ResizeObservation(env, (84, 84)) env = gym.wrappers.GrayScaleObservation(env) env = gym.wrappers.FrameStack(env, 4) - env.action_space.seed(seed) + env.action_space.seed(seed) return env return thunk @@ -205,7 +206,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): actions = torch.argmax(q_values, dim=1).cpu().numpy() # TRY NOT TO MODIFY: execute the game and log data. 
- next_obs, rewards, terminated, truncated, infos = envs.step(actions) + next_obs, rewards, terminations, truncations, infos = envs.step(actions) # TRY NOT TO MODIFY: record rewards for plotting purposes if "final_info" in infos: @@ -217,13 +218,14 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) writer.add_scalar("charts/epsilon", epsilon, global_step) + break # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` real_next_obs = next_obs.copy() - for idx, d in enumerate(truncated): - if d: + for idx, trunc in enumerate(truncations): + if trunc: real_next_obs[idx] = infos["final_observation"][idx] - rb.add(obs, real_next_obs, actions, rewards, terminated, infos) + rb.add(obs, real_next_obs, actions, rewards, terminations, infos) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs diff --git a/cleanrl/dqn_atari_jax.py b/cleanrl/dqn_atari_jax.py index 8d047963a..5f74d57a9 100644 --- a/cleanrl/dqn_atari_jax.py +++ b/cleanrl/dqn_atari_jax.py @@ -94,6 +94,7 @@ def thunk(): else: env = gym.make(env_id) env = gym.wrappers.RecordEpisodeStatistics(env) + env = NoopResetEnv(env, noop_max=30) env = MaxAndSkipEnv(env, skip=4) env = EpisodicLifeEnv(env) @@ -103,8 +104,8 @@ def thunk(): env = gym.wrappers.ResizeObservation(env, (84, 84)) env = gym.wrappers.GrayScaleObservation(env) env = gym.wrappers.FrameStack(env, 4) - env.action_space.seed(seed) + env.action_space.seed(seed) return env return thunk @@ -236,7 +237,7 @@ def mse_loss(params): actions = jax.device_get(actions) # TRY NOT TO MODIFY: execute the game and log data. - next_obs, rewards, terminated, truncated, infos = envs.step(actions) + next_obs, rewards, terminations, truncations, infos = envs.step(actions) # TRY NOT TO MODIFY: record rewards for plotting purposes if "final_info" in infos: @@ -248,13 +249,14 @@ def mse_loss(params): writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) writer.add_scalar("charts/epsilon", epsilon, global_step) + break # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` real_next_obs = next_obs.copy() - for idx, d in enumerate(truncated): - if d: + for idx, trunc in enumerate(truncations): + if trunc: real_next_obs[idx] = infos["final_observation"][idx] - rb.add(obs, real_next_obs, actions, rewards, terminated, infos) + rb.add(obs, real_next_obs, actions, rewards, terminations, infos) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs diff --git a/cleanrl/dqn_jax.py b/cleanrl/dqn_jax.py index 8a5175c6e..1f0eaf623 100644 --- a/cleanrl/dqn_jax.py +++ b/cleanrl/dqn_jax.py @@ -156,9 +156,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int): assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" obs, _ = envs.reset(seed=args.seed) - q_network = QNetwork(action_dim=envs.single_action_space.n) - q_state = TrainState.create( apply_fn=q_network.apply, params=q_network.init(q_key, obs), @@ -208,7 +206,7 @@ def mse_loss(params): actions = jax.device_get(actions) # TRY NOT TO MODIFY: execute the game and log data. 
- next_obs, rewards, terminated, truncated, infos = envs.step(actions) + next_obs, rewards, terminations, truncations, infos = envs.step(actions) # TRY NOT TO MODIFY: record rewards for plotting purposes if "final_info" in infos: @@ -220,13 +218,14 @@ def mse_loss(params): writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step) writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step) writer.add_scalar("charts/epsilon", epsilon, global_step) + break # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` real_next_obs = next_obs.copy() - for idx, d in enumerate(truncated): - if d: + for idx, trunc in enumerate(truncations): + if trunc: real_next_obs[idx] = infos["final_observation"][idx] - rb.add(obs, real_next_obs, actions, rewards, terminated, infos) + rb.add(obs, real_next_obs, actions, rewards, terminations, infos) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs diff --git a/cleanrl/ppo_continuous_action.py b/cleanrl/ppo_continuous_action.py index b61ca73dd..0845222c8 100644 --- a/cleanrl/ppo_continuous_action.py +++ b/cleanrl/ppo_continuous_action.py @@ -207,8 +207,8 @@ def get_action_and_value(self, x, action=None): logprobs[step] = logprob # TRY NOT TO MODIFY: execute the game and log data. - next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy()) - done = np.logical_or(terminated, truncated) + next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy()) + done = np.logical_or(terminations, truncations) rewards[step] = torch.tensor(reward).to(device).view(-1) next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device) diff --git a/cleanrl/qdagger_dqn_atari_impalacnn.py b/cleanrl/qdagger_dqn_atari_impalacnn.py index 15b6e273b..ef7922a91 100644 --- a/cleanrl/qdagger_dqn_atari_impalacnn.py +++ b/cleanrl/qdagger_dqn_atari_impalacnn.py @@ -290,12 +290,12 @@ def kl_divergence_with_logits(target_logits, prediction_logits): else: q_values = teacher_model(torch.Tensor(obs).to(device)) actions = torch.argmax(q_values, dim=1).cpu().numpy() - next_obs, rewards, terminated, truncated, infos = envs.step(actions) + next_obs, rewards, terminations, truncations, infos = envs.step(actions) real_next_obs = next_obs.copy() - for idx, d in enumerate(truncated): - if d: + for idx, trunc in enumerate(truncations): + if trunc: real_next_obs[idx] = infos["final_observation"][idx] - teacher_rb.add(obs, real_next_obs, actions, rewards, terminated, infos) + teacher_rb.add(obs, real_next_obs, actions, rewards, terminations, infos) obs = next_obs # offline training phase: train the student model using the qdagger loss @@ -377,7 +377,7 @@ def kl_divergence_with_logits(target_logits, prediction_logits): actions = torch.argmax(q_values, dim=1).cpu().numpy() # TRY NOT TO MODIFY: execute the game and log data. 
- next_obs, rewards, terminated, truncated, infos = envs.step(actions) + next_obs, rewards, terminations, truncations, infos = envs.step(actions) # TRY NOT TO MODIFY: record rewards for plotting purposes if "final_info" in infos: @@ -394,10 +394,10 @@ def kl_divergence_with_logits(target_logits, prediction_logits): # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation` real_next_obs = next_obs.copy() - for idx, d in enumerate(truncated): - if d: + for idx, trunc in enumerate(truncations): + if trunc: real_next_obs[idx] = infos["final_observation"][idx] - rb.add(obs, real_next_obs, actions, rewards, terminated, infos) + rb.add(obs, real_next_obs, actions, rewards, terminations, infos) # TRY NOT TO MODIFY: CRUCIAL step easy to overlook obs = next_obs diff --git a/cleanrl/rpo_continuous_action.py b/cleanrl/rpo_continuous_action.py index f7338ee7c..919ee72ae 100644 --- a/cleanrl/rpo_continuous_action.py +++ b/cleanrl/rpo_continuous_action.py @@ -216,8 +216,8 @@ def get_action_and_value(self, x, action=None): logprobs[step] = logprob # TRY NOT TO MODIFY: execute the game and log data. - next_obs, reward, terminated, truncated, infos = envs.step(action.cpu().numpy()) - done = np.logical_or(terminated, truncated) + next_obs, reward, terminations, truncations, infos = envs.step(action.cpu().numpy()) + done = np.logical_or(terminations, truncations) rewards[step] = torch.tensor(reward).to(device).view(-1) next_obs, next_done = torch.Tensor(next_obs).to(device), torch.Tensor(done).to(device) diff --git a/cleanrl/sac_atari.py b/cleanrl/sac_atari.py index 07fa8e53c..f7f4ccb99 100644 --- a/cleanrl/sac_atari.py +++ b/cleanrl/sac_atari.py @@ -180,10 +180,10 @@ def get_action(self, x): if sb3.__version__ < "2.0": raise ValueError( """Ongoing migration: run the following command to install the new dependencies: -poetry run pip install "stable_baselines3==2.0.0a1" + +poetry run pip install "stable_baselines3==2.0.0a1" "gymnasium[atari,accept-rom-license]==0.28.1" "ale-py==0.8.1" """ ) - args = parse_args() run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}" if args.track: From d00d87c06d48599f52b92955b8d12efa8da76020 Mon Sep 17 00:00:00 2001 From: Adam Zhao Date: Sun, 15 Oct 2023 13:00:35 +0800 Subject: [PATCH 3/3] update ppo.py --- cleanrl/ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cleanrl/ppo.py b/cleanrl/ppo.py index 506173175..091378209 100644 --- a/cleanrl/ppo.py +++ b/cleanrl/ppo.py @@ -79,7 +79,7 @@ def parse_args(): def make_env(env_id, seed, idx, capture_video, run_name): def thunk(): if capture_video and idx == 0: - env = gym.make(env_id, render_mode="rgb_array") + env = gym.make(env_id) env = gym.wrappers.RecordVideo(env, f"videos/{run_name}") else: env = gym.make(env_id)
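
For reference, the rollout-loop shape that patches 1 and 2 converge on is the Gymnasium vector-env five-tuple step API, with episode statistics read from `infos["final_info"]` and the true last observation of truncated episodes recovered from `infos["final_observation"]` before it is written to the replay buffer. Below is a minimal, self-contained sketch of that pattern, not part of the series itself: it assumes the Gymnasium 0.28/0.29-era vector API that these scripts pin (where autoreset exposes the `final_observation`/`final_info` keys), and `CartPole-v1`, the random actions, and the `print` call are illustrative stand-ins for the per-script environment, policy, and TensorBoard logging. The `rb.add` call is left commented out because the buffer in the actual scripts comes from stable-baselines3.

    import gymnasium as gym

    def make_env():
        # RecordEpisodeStatistics supplies the "episode" dict read below
        return gym.wrappers.RecordEpisodeStatistics(gym.make("CartPole-v1"))

    envs = gym.vector.SyncVectorEnv([make_env, make_env])
    obs, _ = envs.reset(seed=1)
    for global_step in range(500):
        actions = envs.action_space.sample()  # stand-in for the per-algorithm policy
        next_obs, rewards, terminations, truncations, infos = envs.step(actions)

        # episodic logging, as in the patched scripts
        if "final_info" in infos:
            for info in infos["final_info"]:
                if info and "episode" in info:
                    print(global_step, info["episode"]["r"])
                    break

        # replay-buffer view of the next observation: on truncation, pull the real
        # last observation out of the autoreset info dict instead of the reset obs
        real_next_obs = next_obs.copy()
        for idx, trunc in enumerate(truncations):
            if trunc:
                real_next_obs[idx] = infos["final_observation"][idx]
        # rb.add(obs, real_next_obs, actions, rewards, terminations, infos)  # SB3 ReplayBuffer in the real scripts

        obs = next_obs
    envs.close()

Note the design choice the series standardizes on: only `terminations` is stored in the buffer, so bootstrapping is cut on true terminal states but not on time-limit truncations, while the on-policy scripts (ppo_continuous_action.py, rpo_continuous_action.py) combine both signals into `done = np.logical_or(terminations, truncations)` for the rollout mask.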