Reducing Code Redundancy and Minimizing Line Diffs (#422)
* clean-up

* update next_obs, rewards, terminations, truncations, infos

* update ppo.py
sdpkjc authored Oct 15, 2023
1 parent d8d4ebf commit cf20043
Showing 18 changed files with 82 additions and 78 deletions.
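
Before the per-file diffs, it helps to see the shared rollout pattern the scripts converge on after this change: the vector-env step results are consistently named `terminations`/`truncations`, and on truncation the real last observation is restored from `infos["final_observation"]` before the transition is stored. The sketch below is illustrative only: it assumes Gymnasium's vector API, uses CartPole-v1 as a stand-in environment, and comments out the replay-buffer call, whose exact signature depends on the script.

import gymnasium as gym
import numpy as np

# Minimal sketch of the shared rollout loop (an illustration, not the verbatim scripts).
envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")])
obs, _ = envs.reset(seed=0)

for global_step in range(100):
    actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])

    # TRY NOT TO MODIFY: execute the game and log data.
    next_obs, rewards, terminations, truncations, infos = envs.step(actions)

    # On truncation, the autoresetting vector env already returns the first observation
    # of the next episode, so the true final observation must be read from `infos`.
    real_next_obs = next_obs.copy()
    for idx, trunc in enumerate(truncations):
        if trunc:
            real_next_obs[idx] = infos["final_observation"][idx]
    # rb.add(obs, real_next_obs, actions, rewards, terminations, infos)  # replay buffer (assumed)

    obs = next_obs
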
11 changes: 6 additions & 5 deletions cleanrl/c51.py
@@ -196,7 +196,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = actions.cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
@@ -208,13 +208,14 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -257,7 +258,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
loss.backward()
optimizer.step()

# update the target network
# update target network
if global_step % args.target_network_frequency == 0:
target_network.load_state_dict(q_network.state_dict())

14 changes: 8 additions & 6 deletions cleanrl/c51_atari.py
@@ -95,6 +95,7 @@ def thunk():
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)

env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
@@ -104,8 +105,8 @@ def thunk():
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env.action_space.seed(seed)

env.action_space.seed(seed)
return env

return thunk
@@ -218,7 +219,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = actions.cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
@@ -230,13 +231,14 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -279,7 +281,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
loss.backward()
optimizer.step()

# update the target network
# update target network
if global_step % args.target_network_frequency == 0:
target_network.load_state_dict(q_network.state_dict())

14 changes: 8 additions & 6 deletions cleanrl/c51_atari_jax.py
@@ -98,6 +98,7 @@ def thunk():
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)

env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
@@ -107,8 +108,8 @@ def thunk():
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env.action_space.seed(seed)

env.action_space.seed(seed)
return env

return thunk
@@ -278,7 +279,7 @@ def get_action(q_state, obs):
actions = jax.device_get(actions)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
@@ -290,13 +291,14 @@ def get_action(q_state, obs):
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -319,7 +321,7 @@ def get_action(q_state, obs):
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

# update the target network
# update target network
if global_step % args.target_network_frequency == 0:
q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))
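
A side note on the JAX variants: `optax.incremental_update(new, old, step_size)` computes the Polyak average `step_size * new + (1 - step_size) * old`, so with a step size of 1 (as in the line above) it amounts to a hard copy of the online parameters into the target parameters. A small standalone check:

import jax.numpy as jnp
import optax

online = {"w": jnp.array([1.0, 2.0])}
target = {"w": jnp.array([0.0, 0.0])}

soft = optax.incremental_update(online, target, step_size=0.005)  # Polyak averaging
hard = optax.incremental_update(online, target, step_size=1.0)    # plain copy, as in the line above

print(soft["w"])  # [0.005 0.01 ]
print(hard["w"])  # [1. 2.]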

11 changes: 6 additions & 5 deletions cleanrl/c51_jax.py
@@ -242,7 +242,7 @@ def loss(q_params, observations, actions, target_pmfs):
actions = jax.device_get(actions)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
@@ -254,13 +254,14 @@ def loss(q_params, observations, actions, target_pmfs):
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
@@ -283,7 +284,7 @@ def loss(q_params, observations, actions, target_pmfs):
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

# update the target network
# update target network
if global_step % args.target_network_frequency == 0:
q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))

3 changes: 1 addition & 2 deletions cleanrl/ddpg_continuous_action.py
@@ -149,7 +149,6 @@ def forward(self, x):
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
video_filenames = set()

# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
@@ -205,7 +204,7 @@ def forward(self, x):
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, trunc in enumerate(truncations):
if trunc:
3 changes: 1 addition & 2 deletions cleanrl/ddpg_continuous_action_jax.py
@@ -143,7 +143,6 @@ class TrainState(TrainState):
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
video_filenames = set()

# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
@@ -258,7 +257,7 @@ def actor_loss(params):
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, trunc in enumerate(truncations):
if trunc:
9 changes: 5 additions & 4 deletions cleanrl/dqn.py
@@ -183,7 +183,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = torch.argmax(q_values, dim=1).cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
@@ -195,13 +195,14 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
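
Note that the buffer receives `terminations`, not `truncations`: only a genuine terminal state should zero out the bootstrap term of the TD target, whereas a truncated episode keeps its value estimate (which is also why the real final observation is recovered above). A toy, self-contained illustration of that 1-step target (generic DQN-style arithmetic, not the exact code in this file):

import torch

gamma = 0.99
rewards = torch.tensor([1.0, 0.5])
dones = torch.tensor([0.0, 1.0])       # terminations for two sampled transitions
next_q_max = torch.tensor([3.0, 7.0])  # max_a Q_target(s', a) for the same transitions
td_target = rewards + gamma * next_q_max * (1 - dones)
print(td_target)  # tensor([3.9700, 0.5000]); the terminated transition does not bootstrap
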
12 changes: 7 additions & 5 deletions cleanrl/dqn_atari.py
@@ -92,6 +92,7 @@ def thunk():
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)

env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
@@ -101,8 +102,8 @@ def thunk():
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env.action_space.seed(seed)

env.action_space.seed(seed)
return env

return thunk
@@ -205,7 +206,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
actions = torch.argmax(q_values, dim=1).cpu().numpy()

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
@@ -217,13 +218,14 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
12 changes: 7 additions & 5 deletions cleanrl/dqn_atari_jax.py
@@ -94,6 +94,7 @@ def thunk():
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)

env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
@@ -103,8 +104,8 @@ def thunk():
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env.action_space.seed(seed)

env.action_space.seed(seed)
return env

return thunk
@@ -236,7 +237,7 @@ def mse_loss(params):
actions = jax.device_get(actions)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
Expand All @@ -248,13 +249,14 @@ def mse_loss(params):
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
11 changes: 5 additions & 6 deletions cleanrl/dqn_jax.py
@@ -156,9 +156,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

obs, _ = envs.reset(seed=args.seed)

q_network = QNetwork(action_dim=envs.single_action_space.n)

q_state = TrainState.create(
apply_fn=q_network.apply,
params=q_network.init(q_key, obs),
@@ -208,7 +206,7 @@ def mse_loss(params):
actions = jax.device_get(actions)

# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminated, truncated, infos = envs.step(actions)
next_obs, rewards, terminations, truncations, infos = envs.step(actions)

# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
Expand All @@ -220,13 +218,14 @@ def mse_loss(params):
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
writer.add_scalar("charts/epsilon", epsilon, global_step)
break

# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, d in enumerate(truncated):
if d:
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminated, infos)
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
9 changes: 5 additions & 4 deletions cleanrl/ppo.py
@@ -78,11 +78,12 @@ def parse_args():

def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
env = gym.make(env_id)
if capture_video and idx == 0:
env = gym.make(env_id)
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
if capture_video:
if idx == 0:
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
env.seed(seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
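
Reading only the added lines of this hunk, the consolidated thunk presumably ends up in the shape sketched below. This is a reconstruction from the diff, not the verbatim file, and it assumes the older `gym` API in which `env.seed()` still exists (as the unchanged context lines suggest).

import gym


def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        # Single gym.make call; video recording is now just another wrapper applied to env 0.
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk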

1 comment on commit cf20043

@vercel bot commented on cf20043 Oct 15, 2023


Successfully deployed to the following URLs:

cleanrl – ./

cleanrl-git-master-vwxyzjn.vercel.app
cleanrl-vwxyzjn.vercel.app
cleanrl.vercel.app
docs.cleanrl.dev
