From 3f0cd76d47091aa73abc1636fdedf420d8c10e26 Mon Sep 17 00:00:00 2001
From: Noah Farr
Date: Wed, 28 Aug 2024 20:04:16 +0200
Subject: [PATCH] Run pre-commit

---
 cleanrl/td3_continuous_action.py | 87 +++++++++-----------------------
 1 file changed, 23 insertions(+), 64 deletions(-)

diff --git a/cleanrl/td3_continuous_action.py b/cleanrl/td3_continuous_action.py
index 80c2e469..83285391 100644
--- a/cleanrl/td3_continuous_action.py
+++ b/cleanrl/td3_continuous_action.py
@@ -88,8 +88,7 @@ class QNetwork(nn.Module):
     def __init__(self, env):
         super().__init__()
         self.fc1 = nn.Linear(
-            np.array(env.single_observation_space.shape).prod()
-            + np.prod(env.single_action_space.shape),
+            np.array(env.single_observation_space.shape).prod() + np.prod(env.single_action_space.shape),
             256,
         )
         self.fc2 = nn.Linear(256, 256)
@@ -159,8 +158,7 @@ def forward(self, x):
     writer = SummaryWriter(f"runs/{run_name}")
     writer.add_text(
         "hyperparameters",
-        "|param|value|\n|-|-|\n%s"
-        % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
+        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
     )
 
     # TRY NOT TO MODIFY: seeding
@@ -173,14 +171,9 @@ def forward(self, x):
 
     # env setup
     envs = gym.vector.SyncVectorEnv(
-        [
-            make_env(args.env_id, args.seed + i, i, args.capture_video, run_name)
-            for i in range(args.num_envs)
-        ]
+        [make_env(args.env_id, args.seed + i, i, args.capture_video, run_name) for i in range(args.num_envs)]
     )
-    assert isinstance(
-        envs.single_action_space, gym.spaces.Box
-    ), "only continuous action space is supported"
+    assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"
 
     actor = Actor(envs).to(device)
     qf1 = QNetwork(envs).to(device)
@@ -191,9 +184,7 @@ def forward(self, x):
     target_actor.load_state_dict(actor.state_dict())
     qf1_target.load_state_dict(qf1.state_dict())
     qf2_target.load_state_dict(qf2.state_dict())
-    q_optimizer = optim.Adam(
-        list(qf1.parameters()) + list(qf2.parameters()), lr=args.learning_rate
-    )
+    q_optimizer = optim.Adam(list(qf1.parameters()) + list(qf2.parameters()), lr=args.learning_rate)
     actor_optimizer = optim.Adam(list(actor.parameters()), lr=args.learning_rate)
 
     envs.single_observation_space.dtype = np.float32
@@ -212,18 +203,12 @@ def forward(self, x):
     for global_step in range(args.total_timesteps):
         # ALGO LOGIC: put action logic here
         if global_step < args.learning_starts:
-            actions = np.array(
-                [envs.single_action_space.sample() for _ in range(envs.num_envs)]
-            )
+            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
         else:
             with torch.no_grad():
                 actions = actor(torch.Tensor(obs).to(device))
                 actions += torch.normal(0, actor.action_scale * args.exploration_noise)
-                actions = (
-                    actions.cpu()
-                    .numpy()
-                    .clip(envs.single_action_space.low, envs.single_action_space.high)
-                )
+                actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high)
 
         # TRY NOT TO MODIFY: execute the game and log data.
         next_obs, rewards, terminations, truncations, infos = envs.step(actions)
@@ -232,15 +217,9 @@ def forward(self, x):
         if "final_info" in infos:
             for info in infos["final_info"]:
                 if info is not None:
-                    print(
-                        f"global_step={global_step}, episodic_return={info['episode']['r']}"
-                    )
-                    writer.add_scalar(
-                        "charts/episodic_return", info["episode"]["r"], global_step
-                    )
-                    writer.add_scalar(
-                        "charts/episodic_length", info["episode"]["l"], global_step
-                    )
+                    print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
+                    writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
+                    writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                     break
 
         # TRY NOT TO MODIFY: save data to reply buffer; handle `final_observation`
@@ -257,21 +236,17 @@ def forward(self, x):
         if global_step > args.learning_starts:
             data = rb.sample(args.batch_size)
             with torch.no_grad():
-                clipped_noise = (
-                    torch.randn_like(data.actions, device=device) * args.policy_noise
-                ).clamp(-args.noise_clip, args.noise_clip) * target_actor.action_scale
+                clipped_noise = (torch.randn_like(data.actions, device=device) * args.policy_noise).clamp(
+                    -args.noise_clip, args.noise_clip
+                ) * target_actor.action_scale
 
-                next_state_actions = (
-                    target_actor(data.next_observations) + clipped_noise
-                ).clamp(
+                next_state_actions = (target_actor(data.next_observations) + clipped_noise).clamp(
                     envs.single_action_space.low[0], envs.single_action_space.high[0]
                 )
                 qf1_next_target = qf1_target(data.next_observations, next_state_actions)
                 qf2_next_target = qf2_target(data.next_observations, next_state_actions)
                 min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)
-                next_q_value = data.rewards.flatten() + (
-                    1 - data.dones.flatten()
-                ) * args.gamma * (min_qf_next_target).view(-1)
+                next_q_value = data.rewards.flatten() + (1 - data.dones.flatten()) * args.gamma * (min_qf_next_target).view(-1)
 
             qf1_a_values = qf1(data.observations, data.actions).view(-1)
             qf2_a_values = qf2(data.observations, data.actions).view(-1)
@@ -291,32 +266,16 @@ def forward(self, x):
                 actor_optimizer.step()
 
                 # update the target network
-                for param, target_param in zip(
-                    actor.parameters(), target_actor.parameters()
-                ):
-                    target_param.data.copy_(
-                        args.tau * param.data + (1 - args.tau) * target_param.data
-                    )
-                for param, target_param in zip(
-                    qf1.parameters(), qf1_target.parameters()
-                ):
-                    target_param.data.copy_(
-                        args.tau * param.data + (1 - args.tau) * target_param.data
-                    )
-                for param, target_param in zip(
-                    qf2.parameters(), qf2_target.parameters()
-                ):
-                    target_param.data.copy_(
-                        args.tau * param.data + (1 - args.tau) * target_param.data
-                    )
+                for param, target_param in zip(actor.parameters(), target_actor.parameters()):
+                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
+                for param, target_param in zip(qf1.parameters(), qf1_target.parameters()):
+                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
+                for param, target_param in zip(qf2.parameters(), qf2_target.parameters()):
+                    target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
 
             if global_step % 100 == 0:
-                writer.add_scalar(
-                    "losses/qf1_values", qf1_a_values.mean().item(), global_step
-                )
-                writer.add_scalar(
-                    "losses/qf2_values", qf2_a_values.mean().item(), global_step
-                )
+                writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
+                writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
                 writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
                 writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
                 writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)